diff --git a/.github/workflows/doc.yaml b/.github/workflows/doc.yaml new file mode 100644 index 00000000..fba9b693 --- /dev/null +++ b/.github/workflows/doc.yaml @@ -0,0 +1,62 @@ +name: Deploy MkDocs to GitHub Pages + +on: + push: + branches: + - main + workflow_dispatch: + +permissions: + contents: read + pages: write + id-token: write + +concurrency: + group: pages + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('docs/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -r docs/requirements.txt + + - name: Build documentation + run: mkdocs build --clean + + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: ./site + + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml new file mode 100644 index 00000000..3b4eae9c --- /dev/null +++ b/.github/workflows/docker.yaml @@ -0,0 +1,67 @@ +# +name: Create and publish a Docker image + +on: + release: + types: [published] + + workflow_dispatch: + inputs: + ref: + description: 'The git ref (branch or tag) to build the Docker image from.' + required: true + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-push-image: + runs-on: self-hosted + timeout-minutes: 240 # wait up to 4 hours + # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. + permissions: + contents: read + packages: write + attestations: write + id-token: write + steps: + - name: Checkout repository + uses: actions/checkout@v5 + # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. + - name: Log in to the Container registry + uses: docker/login-action@v2 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. + - name: Extract metadata (tags, labels) for Docker + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=match,pattern=\d.\d.\d + type=sha + # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. + # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see [Usage](https://github.com/docker/build-push-action#usage) in the README of the `docker/build-push-action` repository. + # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. + - name: Build and push Docker image + id: push + uses: docker/build-push-action@v4 + with: + context: . + push: true + file: scripts/docker/dockerfile + shm-size: 64g + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + + # This step generates an artifact attestation for the image, which is an unforgeable statement about where and how it was built. It increases supply chain security for people who consume the image. For more information, see [Using artifact attestations to establish provenance for builds](/actions/security-guides/using-artifact-attestations-to-establish-provenance-for-builds). + - name: Generate artifact attestation + uses: actions/attest-build-provenance@v3 + with: + subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} + subject-digest: ${{ steps.push.outputs.digest }} + push-to-registry: true diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml new file mode 100644 index 00000000..8b08b505 --- /dev/null +++ b/.github/workflows/docker/docker-compose.yaml @@ -0,0 +1,26 @@ +services: + + ajet-node-1: + image: agentjet-unittest:latest + pull_policy: never + command: bash -c "uv pip install -e .[trinity] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block" + environment: + - PATH=/opt/venv/bin:$PATH + - HF_ENDPOINT=https://hf-mirror.com + - RAY_ADDRESS=auto + working_dir: /workspace + volumes: + - ajet-volume:/mnt + - ../../..:/workspace + shm_size: "64G" + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0', '1', '2', '3'] + capabilities: [gpu] + +volumes: + ajet-volume: + external: true diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 00000000..af12d999 --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,13 @@ +name: pre-commit + +on: [push, pull_request] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml new file mode 100644 index 00000000..fa2a7df5 --- /dev/null +++ b/.github/workflows/unittest.yaml @@ -0,0 +1,111 @@ +name: unittest + +on: + issue_comment: + types: [created] + +permissions: + contents: write + checks: write + pull-requests: write + +jobs: + unittest: + # only run on pull request + if: ${{ github.event.issue.pull_request && (startsWith(github.event.comment.body, '/unittest')) && github.event.comment.author_association == 'COLLABORATOR' }} + runs-on: self-hosted + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + path: ajet-${{ github.run_id }} + ref: refs/pull/${{ github.event.issue.number }}/head + + - name: Setup docker compose + working-directory: ajet-${{ github.run_id }}/.github/workflows/docker + run: | + docker compose up -d + sleep 15s + + - name: Check ray status + working-directory: ajet-${{ github.run_id }}/.github/workflows/docker + run: | + MAX_RETRIES=20 + RETRY_INTERVAL=5 + for i in $(seq 1 $MAX_RETRIES); do + docker compose exec ajet-node-1 ray status && docker compose exec ajet-node-2 ray status && break + echo "Waiting for ray cluster to be ready... ($i/$MAX_RETRIES)" + sleep $RETRY_INTERVAL + if [ "$i" -eq "$MAX_RETRIES" ]; then + echo "Ray cluster failed to start after $MAX_RETRIES retries." + exit 1 + fi + done + + - name: Decide test type + id: test_type + working-directory: ajet-${{ github.run_id }} + run: | + COMMENT="${{ github.event.comment.body }}" + if [[ "$COMMENT" == "/unittest"* ]]; then + echo "type=all" >> $GITHUB_OUTPUT + fi + + - name: Run unittest + working-directory: ajet-${{ github.run_id }}/.github/workflows/docker + run: | + TYPE="${{ steps.test_type.outputs.type }}" + if [ "$TYPE" = "all" ]; then + echo "tests_run=true" >> $GITHUB_ENV + docker compose exec ajet-node-1 pytest tests -v -s --ctrf report.json + fi + + - name: Convert report.json time to ms + working-directory: ajet-${{ github.run_id }} + if: env.tests_run == 'true' || failure() + run: | + REPORT=report.json + if [ -f "$REPORT" ]; then + jq '(.results.tests[] | .duration, .start, .stop) |= (. * 1000) | (.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT" + fi + + - name: Clean checkpoint dir + working-directory: ajet-${{ github.run_id }}/.github/workflows/docker + if: always() + run: | + docker compose exec ajet-node-1 rm -rf /mnt/checkpoints/* + continue-on-error: true + + - name: Upload test results + if: env.tests_run == 'true' || failure() + uses: actions/upload-artifact@v4 + with: + name: pytest-results + path: ajet-${{ github.run_id }}/report.json + continue-on-error: true + + - name: Publish Test Report + if: env.tests_run == 'true' || failure() + uses: ctrf-io/github-test-reporter@v1 + with: + report-path: ajet-${{ github.run_id }}/report.json + summary: true + pull-request: false + issue: ${{ github.event.issue.number }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + continue-on-error: true + + - name: Remove docker compose + working-directory: ajet-${{ github.run_id }}/.github/workflows/docker + if: always() + run: | + docker compose down --remove-orphans + continue-on-error: true + + - name: Cleanup workspace + if: always() + run: | + rm -rf ajet-${{ github.run_id }} 2>/dev/null + continue-on-error: true diff --git a/.gitignore b/.gitignore index 128696a9..c16d08c4 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,17 @@ launcher_record .trash trinity_checkpoints +good_trinity_checkpoints +_resources +auto/auto_grader.json +tutorial/example_math_agent/math_agent_debug.yaml +saved_experiments +tests/temp +vsdb.py +appworld_pack_v2.tar* +saved_checkpoints +data +datasets +tutorial2 +site +dump.rdb diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 12e7e9d8..00000000 --- a/.gitmodules +++ /dev/null @@ -1,9 +0,0 @@ -[submodule "external/envservice"] - path = external/envservice - url = http://gitlab.alibaba-inc.com/EconML/EnvService.git -[submodule "external/verl"] - path = external/verl - url = https://github.com/volcengine/verl.git -[submodule "external/experiencemaker"] - path = external/experiencemaker - url = http://gitlab.alibaba-inc.com/OpenRepo/ExperienceMaker.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 72e099f2..33bde514 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,8 +1,55 @@ repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.11.4" + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 hooks: - - id: ruff - args: ["--fix", "--show-fixes", "--output-format=full"] - exclude: ^.*\.(ipynb)$ - - id: ruff-format + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: detect-private-key + + - repo: https://github.com/psf/black + rev: 23.7.0 + hooks: + - id: black + language_version: python3.10 + args: [--line-length=999999] + + # - repo: https://github.com/pycqa/isort + # rev: 5.12.0 + # hooks: + # - id: isort + # args: ["--profile", "black", "--filter-files"] + + # - repo: https://github.com/pycqa/flake8 + # rev: 6.1.0 + # hooks: + # - id: flake8 + # additional_dependencies: [flake8-docstrings] + # args: [ + # "--max-line-length=100", + # "--max-complexity=20", + # "--select=C,E,F,W,B,B950", + # "--ignore=E203,E266,E501,W503", + # ] + + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.7.0 + # hooks: + # - id: mypy + # args: [ + # --ignore-missing-imports, + # --disable-error-code=var-annotated, + # --disable-error-code=union-attr, + # --disable-error-code=no-redef, + # --disable-error-code=assignment, + # --disable-error-code=has-type, + # --disable-error-code=attr-defined, + # --disable-error-code=import-untyped, + # --disable-error-code=truthy-function, + # --follow-imports=skip, + # --explicit-package-bases, + # ] diff --git a/README.md b/README.md index 5e940356..b01cff25 100644 --- a/README.md +++ b/README.md @@ -1,165 +1,155 @@ -# AgentScope Tune +# AgentJet (Beta) -AgentScope Tune, or **ASTune**, is an advanced agent training framework for tuning AgentScope workflow and agent(s). +[![Benchmarking](https://img.shields.io/badge/Benchmarking-0078D4?style=for-the-badge&logo=github)](https://benchmark.agent-matrix.com/) +[![Docs](https://img.shields.io/badge/Docs-Read%20the%20Documents-0A7ECC?style=for-the-badge&logo=readthedocs&logoColor=white)](https://modelscope.github.io/AgentJet) +[![License](https://img.shields.io/badge/License-Apache--2.0-4c1?style=for-the-badge)](LICENSE) +[![Python](https://img.shields.io/badge/Python-3.10+-3776AB?style=for-the-badge&logo=python&logoColor=white)](https://modelscope.github.io/AgentJet/en/installation#requirements) +
+ + AgentJet + +
-## Installation -You can choose between `Trinity training backbone` and `Verl training backbone`. We recommend using `uv` to setup the dependencies and `conda` also works. +**AgentJet (AJet)** is a cutting-edge, user-friendly training framework designed to optimize agents and workflows (built with OpenAI SDK, AgentScope, Langchain, or just HTTP requests), fine-tuning language model weights behind the scenes. -1. Trinity backbone (Option 1) +Simply provide your agent **workflow**, training **dataset**, and **reward** function, and **AgentJet** will be ready to enhance your agents to their optimal performance! -```bash -# Create virtual environment -uv venv --python=3.10.16 -source .venv/bin/activate -git clone https://github.com/binary-husky/Trinity-RFT external/trinity -# Install dependencies -uv pip install --upgrade pip setuptools packaging -i https://mirrors.aliyun.com/pypi/simple/ -uv pip install -r requirements_trinity.txt -i https://mirrors.aliyun.com/pypi/simple/ --no-deps --prerelease=allow -uv pip install -e external/trinity -i https://mirrors.aliyun.com/pypi/simple/ --no-deps -# Install flash attention (must be installed last) -uv pip install --verbose flash-attn ring-flash-attn -i https://mirrors.aliyun.com/pypi/simple/ --no-deps --no-build-isolation -``` +## ✈️ Minimum Example +Let's begin with the simplest example: a math agent with a tool call. -2. VERL Backbone (Option 2) +- First, please check out the [installation guide](https://modelscope.github.io/AgentJet/en/installation/) to set up the training environment. +- Then, tune your first model using the minimum example. + ```python + ajet --conf tutorial/example_math_agent/math_agent.yaml --backbone='verl' -```bash -# Create virtual environment -uv venv --python=3.10.16 -source .venv/bin/activate -git clone https://github.com/binary-husky/verl.git external/verl + # change to --backbone='trinity' if you want to switch to trinity training engine; + # or --backbone='debug' if you want to debug with only vLLM + ``` -# Install dependencies -uv pip install --upgrade pip setuptools packaging -i https://mirrors.aliyun.com/pypi/simple/ -uv pip install -r requirements_verl.txt -i https://mirrors.aliyun.com/pypi/simple/ --no-deps --prerelease=allow -uv pip install -e external/verl -i https://mirrors.aliyun.com/pypi/simple/ --no-deps -# Install flash attention (must be installed last) -uv pip install --verbose flash-attn ring-flash-attn -i https://mirrors.aliyun.com/pypi/simple/ --no-deps --no-build-isolation -``` +## ✈️ Features -注意:二者不能同时安装 -```bash -# verl -> trinity -cd external/verl && uv pip uninstall . && cd ../.. -# trinity -> verl -uv pip install -e external/verl -i https://mirrors.aliyun.com/pypi/simple/ --no-deps -``` +We aim to build a easy-to-learn Agent tuner that unlock more possibilities for agent developers: + +- **Easy and Friendly**. AgentJet helps you tune models behind your agent workflows easily, optimizing your agents for top performance with minimal effort. +- **Rich Tutorial Library**. AgentJet provides a rich library of [examples](https://github.com/modelscope/AgentJet/tree/main/tutorial) as tutorials. +- **Efficient and Scalable**. AgentJet uses [verl] as the default backbone (`--backbone=verl`). However, we also support [trinity](https://github.com/modelscope/Trinity-RFT/) as alternative backbone, accelerating your tuning process via fully asynchronous RFT. +- **Flexible and Fast**. AgentJet supports [multi-agent workflows](https://modelscope.github.io/AgentJet/en/workflow/) and adopts a context merging technique, accelerating training by 1.5x to 10x when the workflow involves multi-turn (or multi-agent) conversations. +- **Reliability and Reproducibility**. Our team keeps track of framework performance across multiple [tasks + major-git-version + training-backbones](https://benchmark.agent-matrix.com/) (under construction, still gathering data, coming soon). + +For advanced researchers, AgentJet also provides high-resolution logging and debugging solutions: + + +- **High-Resolution Logging**: AgentJet allows users to save and inspect token-level rollout details, recording token IDs, token loss masks, and even token logprobs to facilitate workflow development and agent diagnostics. +- **Fast Debugging**: AgentJet also provides the `--backbone=debug` option for the best debugging experience, shortening your wait period from minutes to seconds after code changes and enabling breakpoint debugging in IDEs. + +--- + +### ✈️ Quick Start + +#### Installation + +- **Click here to read the** [**installation guide**](https://modelscope.github.io/AgentJet/en/installation/). + +#### Run Training + +- You can start training your first agent with a single command using a pre-configured YAML file. Take the [Math agent](https://modelscope.github.io/AgentJet/en/example_math_agent/) as an example: + + ```bash + ajet --conf tutorial/example_math_agent/math_agent.yaml + ``` + +#### Example Library + +Explore our rich library of examples to kickstart your journey: + +- 🔢 [**Training a math agent that can write python code**](https://modelscope.github.io/AgentJet/en/example_math_agent). +- 📱 [**Creating an AppWorld agent using AgentScope and training it**](https://modelscope.github.io/AgentJet/en/example_app_world). +- 🐺 [**Developing Werewolves RPG agents and training them**](https://modelscope.github.io/AgentJet/en/example_werewolves). +- 👩🏻‍⚕️ [**Learning to ask questions like a doctor**](https://modelscope.github.io/AgentJet/en/example_learning_to_ask). +- 🎴 [**Writing a countdown game using AgentScope and solving it**](https://modelscope.github.io/AgentJet/en/example_countdown). +- 🚶 [**Solving a frozen lake walking puzzle using AgentJet**](https://modelscope.github.io/AgentJet/en/example_frozenlake). + + +--- + +### ✈️ Core Concepts + +AgentJet makes agent fine-tuning straightforward by separating the developer interface from the internal execution logic.
- 项目架构图 +image +
-## Get Started +#### 1. The User-Centric Interface + +To optimize an agent, you provide three core inputs: + +* [**Trainable Workflow**](https://modelscope.github.io/AgentJet/en/workflow): Define your agent logic by inheriting the Workflow class, supporting both simple agent setups and advanced multi-agent collaborations. +* [**Task Reader**](https://modelscope.github.io/AgentJet/en/data_pipeline): Load training tasks from JSONL files, HuggingFace datasets, interactive environments, or auto-generate them from documents. +* [**Task Judger**](https://modelscope.github.io/AgentJet/en/task_judger): Evaluates agent outputs and assigns rewards to guide training. + +#### 2. Internal System Architecture + +The internal system orchestrates several specialized modules to handle the complexities of RL training and agent interactions. + +* **Launcher**: Manages background service processes (Ray, vLLM) and routes the backbone. +* **Task Reader**: Handles data ingestion, augmentation, and filtering. +* **Task Rollout**: Bridges LLM engines and manages the Gym environment lifecycle. +* **Task Runner**: Executes the Agent workflow and calculates rewards. +* **Model Tuner**: Forwards inference requests from the workflow to the LLM engine. +* **Context Tracker**: Monitors LLM calls and automatically merges shared-history timelines to improve training efficiency by **1.5x to 10x**. + + + + +### ✈️ Navigation + +* **Tutorials**: From [Installation](https://modelscope.github.io/AgentJet/en/installation) to [Tuning your first agent](https://modelscope.github.io/AgentJet/en/tune_your_first_agent) — the essential path for beginners. +* **Core Components**: Define your [Trainable Workflow](https://modelscope.github.io/AgentJet/en/workflow) and manage [Data](https://modelscope.github.io/AgentJet/en/data_pipeline) and [Reward](https://modelscope.github.io/AgentJet/en/task_judger). +* **Example**: Check the [Example Library](https://modelscope.github.io/AgentJet/#example-library) above for real-world cases like [Math](https://modelscope.github.io/AgentJet/en/example_math_agent), [Werewolves game](https://modelscope.github.io/AgentJet/en/example_werewolves) and [Learning to ask task](https://modelscope.github.io/AgentJet/en/example_learning_to_ask). +* **Deep Dive**: Master advanced [Configuration](https://modelscope.github.io/AgentJet/en/configuration). -本节仅内部沟通使用,后期重写。 +## ✈️ Roadmap -项目提供一个多功能launcher用于调试和训练,借助launcher,只需要修改一个`--backbone`参数,就选择任意训练框架启动训练 or 调试。 +AgentJet is a constantly evolving project. We are planning to add the following features in the near future. -1. 使用launcher进行全链路调试(--backbone='debug'):脱离trinity和verl,只与vllm(自动创建)连接,进行调试 - ```bash - # (训练math agent demo)建议开始前杀死所有ray、env_service进程 (python launcher.py --kill="python|ray|vllm|VLLM" && ray stop) - clear && \ - python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbone='debug' --with-logview - - # (训练appworld demo)建议开始前杀死所有ray、env_service进程 (python launcher.py --kill="python|ray|vllm|VLLM" && ray stop) - clear && \ - python launcher.py --with-appworld --conf launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml --backbone='debug' --with-logview - ``` -备注:当--backbone=debug时,程序不再使用ray,可以编写vscode的launch.json进行便捷的断点调试,launch.json的配置见文档最后 - - -2. 使用launcher进行训练:使用trinity进行训练 - ```bash - # 建议开始前杀死所有ray、vllm、env_service进程 (python launcher.py --kill="python|ray|vllm|VLLM" && ray stop) - clear && \ - python launcher.py --with-appworld --conf launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml --with-ray --backbone='trinity' - - python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --with-ray --backbone='trinity' - ``` -备注:如果需要断点调试,请添加参数 `python launcher.py --db='TAG1|TAG2|TAG3' --conf=...`,并在代码中需要断点的地方标记一行特殊代码 `from vsdb import bp; bp("TAG1")` 即可。(需要配合Ray Distributed Debugger VSCode Extension) - - -3. 使用launcher进行训练:使用verl进行训练 - ```bash - # 建议开始前杀死所有ray、vllm、env_service进程 (python launcher.py --kill="python|ray|vllm|VLLM" && ray stop) - clear && \ - python launcher.py --with-appworld --conf launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml --backbone='verl' - - python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbone='verl' - ``` -备注:如果需要断点调试,请添加参数 `python launcher.py --db='TAG4|TAG5|TAG6' --conf=...`,并在代码中需要断点的地方标记一行特殊代码 `from vsdb import bp; bp("TAG4")` 即可。 - - -# 简要架构 - -1. 读取task(对应配置字段 astune.task_reader) - - astune/task_reader/task_reader_base.py - - class::TaskReaderEnvService - - class::TaskReaderJsonl - - class::TaskReaderHuggingFace - -2. 定义 AgentScopeWorkflow(对应配置字段 astune.rollout.agentscope_learn_protocol ) - - tutorial/appworld.py - - tutorial/math_agent.py - -3. 定义评分函数(对应配置字段 astune.task_judge.judge_protocol ) - - astune/task_judge/judge_base.py - - astune/task_judge/env_service_as_judge.py - - class::EnvServiceJudge - - astune/task_judge/math_answer_as_judge.py - - class::MathAnswerAsJudge - - class::MathAnswerAndLlmAsJudge - -4. 指定模型(对应配置字段 astune.model.path ) - -5. 配置系统(完善中,先凑合着用一下) - - 默认配置 - - astune/default_config/default.yaml ( 存储verl的默认训练配置,可以被 --conf 指定的 yaml 以同名配置的形式覆盖 ) - - astune/default_config/trinity_default.yaml ( 存储trinity默认配置,但可以被 --conf 指定的 yaml 以 trinity.xxx 的形式覆盖 ) - - 配置自动对齐(定义哪些参数需要自动对齐到verl或者trinity上) - - astune/default_config/config_auto_convertion_verl.json - - astune/default_config/config_auto_convertion_trinity.json - -6. ASTune 和 AgentScope 交互系统 V0.5 - - astune/context_manager/cmt_agentscope.py 负责 - - 处理AgentScope生成的Token - - 缓存judge给出评分需要的各种数据(包括但不限于所有对话message,env_service句柄,从astune.task_reader读取的task metadata等) - - 提供LLM的桥接 - - 负责合并timeline - -# note - -FlashInfer? - -clear && killer VLLM && killer ray && killer python && python launcher.py --with-appworld --conf launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml --with-ray --backbone='verl' - -clear && killer VLLM && killer ray && killer python && python launcher.py --with-appworld --conf launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml --with-ray --backbone='verl' - - -- `launche.json` for vscode debugging -```json -{ - - "version": "0.2.0", - "configurations": [ - { - "name": "Python Debugger: Launch rollout", - "type": "debugpy", - "request": "launch", - "program": "launcher.py", - "console": "integratedTerminal", - "args": [ - "--backbone", "debug", - "--conf", "xxxx/xxxx/xxxx.yaml" - ], - "env": { - } - }, - ] +| Category | Feature | Status | +| :--- | :--- | :--- | +| **Examples** | Covering LangGraph and AutoGen frameworks | Done & Verifying | +| **Examples** | Add LoRA training examples | Todo | +| **Infra** | Cross-process Tuner wrapper to pass though process forking | Done & Verifying | +| **Infra** | Optimize configurations for long-context adaptation on smaller GPUs | In Progress | +| **Capability** | Prompt tuning | In Progress | +| **Capability** | Multi-modal training support | Todo | +| **Capability** | MARL Credit assignment | Todo | +| **Capability** | Training dataset generation from few-shot samples | Done & Verifying | + + +## ✈️ Citation + +If you use AgentJet in your research, please cite: + +```bibtex +@software{ + title = {AgentJet: A Cutting-Edge Multi-Agent Training Platform for Large Language Models.}, + author = {The AgentJet Team}, + url = {https://modelscope.github.io/AgentJet/}, + month = {01}, + year = {2026} } -``` \ No newline at end of file +``` + +
+ +--- +
+ +[⭐ Star Us](https://github.com/modelscope/AgentJet) · [Report Bug](https://github.com/modelscope/AgentJet/issues) · [Request Feature](https://github.com/modelscope/AgentJet/issues) +
diff --git a/ajet/__init__.py b/ajet/__init__.py new file mode 100644 index 00000000..12a07cc1 --- /dev/null +++ b/ajet/__init__.py @@ -0,0 +1,9 @@ +from ajet.copilot.job import AgentJetJob +from ajet.schema.task import WorkflowOutput, WorkflowTask +from ajet.tuner import AjetTuner +from ajet.workflow import Workflow +from ajet.utils.vsdb import vscode_conditional_breakpoint as bp + +__all__ = ["Workflow", "WorkflowTask", "WorkflowOutput", "AjetTuner", "AgentJetJob", "bp"] + +__version__ = "0.1.0" diff --git a/ajet/backbone/__init__.py b/ajet/backbone/__init__.py new file mode 100644 index 00000000..4233c9dd --- /dev/null +++ b/ajet/backbone/__init__.py @@ -0,0 +1,16 @@ +from loguru import logger + +try: + from ajet.backbone.trainer_trinity import ( + AjetTaskReader, + AjetWorkflowWrap, + TrinityRolloutManager, + ) + + __all__ = [ + "TrinityRolloutManager", + "AjetWorkflowWrap", + "AjetTaskReader", + ] +except ImportError: + logger.warning("trinity is not available.") diff --git a/ajet/backbone/main_trinity.py b/ajet/backbone/main_trinity.py new file mode 100644 index 00000000..06dd844f --- /dev/null +++ b/ajet/backbone/main_trinity.py @@ -0,0 +1,63 @@ +import ray +import os +from trinity.cli.launcher import main +from trinity.common.config import Config +from trinity.explorer.explorer import Explorer +from trinity.trainer.trainer import Trainer + +from ajet.utils.config_utils import read_ajet_config_with_cache +from ajet.utils.core_env_vars import get_runtime_env +from ajet.utils.launch_utils import set_loguru_default_color + + +set_loguru_default_color() + + +def get_ajet_config_from_trinity_side(): + yaml_path = os.environ.get("AJET_CONFIG_REDIRECT", None) + if yaml_path is None: + raise ValueError("AJET_CONFIG_REDIRECT is not set in environment variables") + ajet_config = read_ajet_config_with_cache(yaml_path) + return ajet_config + + +def patch_runtime_env_to_get_actor(): + """Patch the classmethod of Explorer and Trainer to pass in the runtime env.""" + ajet_config = get_ajet_config_from_trinity_side() + runtime_env = get_runtime_env(ajet_config, is_trinity=True) + os.environ.update(runtime_env["env_vars"]) + + def patched_explorer_get_actor(cls, config: Config): + return ( + ray.remote(cls) + .options( + name=config.explorer.name, + namespace=ray.get_runtime_context().namespace, + runtime_env=runtime_env, + ) + .remote(config) + ) + + def patched_trainer_get_actor(cls, config: Config): + return ( + ray.remote(cls) + .options( + name=config.trainer.name, + namespace=ray.get_runtime_context().namespace, + runtime_env=runtime_env, + ) + .remote(config) + ) + + Explorer.get_actor = classmethod(patched_explorer_get_actor) + Trainer.get_actor = classmethod(patched_trainer_get_actor) + + if ajet_config.ajet.enable_experimental_interchange_server: + from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server + + start_interchange_server(ajet_config) + + +if __name__ == "__main__": + patch_runtime_env_to_get_actor() + main() diff --git a/astune/main_verl.py b/ajet/backbone/main_verl.py similarity index 72% rename from astune/main_verl.py rename to ajet/backbone/main_verl.py index 5ef66a77..8b2fc57a 100644 --- a/astune/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -15,27 +15,23 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ - -def send_train_message(message: str): - import requests, os # 发送短信汇报训练进程 - assert len(message) < 64, f"Message too long: {(message)}" - if os.getenv("ALIYUN_SMS_SERVICE") and os.getenv("SMS"): - print("尝试发送短信:", message) - try: requests.post(json={"phone_numbers": "18810508767", "server_code": "DLC", "error": message, "error_level": "无"}, url=os.getenv("ALIYUN_SMS_SERVICE", "http://localhost:8000/send-sms"), headers={"Content-Type": "application/json"}) - except Exception as e: print(f"Failed to send sms: {e}") - +import atexit import os import socket + import hydra import ray +from beast_logger import print_dict +from loguru import logger from omegaconf import OmegaConf -from verl.experimental.dataset.sampler import AbstractSampler -from verl.trainer.constants_ppo import get_ppo_ray_runtime_env -from verl.trainer.ppo.ray_trainer import RayPPOTrainer from verl.trainer.ppo.reward import load_reward_manager from verl.utils.device import is_cuda_available -from verl.utils.import_utils import load_extern_type -from beast_logger import register_logger, print_dict + +from ajet.utils.core_env_vars import get_runtime_env +from ajet.utils.launch_utils import set_loguru_default_color + +set_loguru_default_color() + @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) def main(config): @@ -59,42 +55,18 @@ def run_ppo(config) -> None: # Check if Ray is not initialized if not ray.is_initialized(): # this is for local ray cluster - runtime_env={ - "env_vars": - { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", - "VLLM_USE_V1": "1", - "SWANLAB_API_KEY": os.getenv("SWANLAB_API_KEY"), - } - } - - if os.getenv("RAY_record_task_actor_creation_sites"): - runtime_env["env_vars"].update({ - "RAY_record_task_actor_creation_sites": os.getenv("RAY_record_task_actor_creation_sites"), - }) - if os.getenv("BEST_LOGGER_WEB_SERVICE_URL"): - runtime_env["env_vars"].update({ - "BEST_LOGGER_WEB_SERVICE_URL": os.getenv("BEST_LOGGER_WEB_SERVICE_URL"), - }) + runtime_env = get_runtime_env(config) print_dict(runtime_env["env_vars"], "runtime_env") ray.init( runtime_env=runtime_env, num_cpus=config.ray_init.num_cpus, ) - import atexit - atexit.register(lambda: send_train_message("注意:训练结束")) # 如果环境变量存在,则在程序结束时发送短信 + atexit.register(lambda: ray.shutdown()) # ray shutdown on exit # Create a remote instance of the TaskRunner class, and # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete - if ( - is_cuda_available - and config.trainer.get("profile_steps") is not None - and len(config.trainer.get("profile_steps", [])) > 0 - ): + if is_cuda_available and config.trainer.get("profile_steps") is not None and len(config.trainer.get("profile_steps", [])) > 0: from verl.utils.import_utils import is_nvtx_available assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" @@ -133,17 +105,17 @@ def run(self, config): from pprint import pprint from omegaconf import OmegaConf - from verl.utils.fs import copy_to_local - print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + logger.info(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") pprint(OmegaConf.to_container(config, resolve=True)) OmegaConf.resolve(config) # Download the checkpoint from HDFS to the local machine. # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on local_path = copy_to_local( - config.astune.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + config.ajet.model.path, + use_shm=config.actor_rollout_ref.model.get("use_shm", False), ) # Instantiate the tokenizer and processor. @@ -158,7 +130,10 @@ def run(self, config): if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: assert config.critic.strategy in {"fsdp", "fsdp2"} from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + from verl.workers.fsdp_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + ) use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") if use_legacy_worker_impl in ["auto", "enable"]: @@ -168,34 +143,28 @@ def run(self, config): from verl.workers.fsdp_workers import CriticWorker elif use_legacy_worker_impl == "disable": from verl.workers.roles import CriticWorker - - print("Using new worker implementation") else: raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") - actor_rollout_cls = ( - AsyncActorRolloutRefWorker - if config.astune.rollout.mode == "async" - else ActorRolloutRefWorker - ) + actor_rollout_cls = AsyncActorRolloutRefWorker ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker - - actor_rollout_cls = ( - AsyncActorRolloutRefWorker - if config.astune.rollout.mode == "async" - else ActorRolloutRefWorker + from verl.workers.megatron_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + CriticWorker, ) + + actor_rollout_cls = AsyncActorRolloutRefWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: raise NotImplementedError - from astune.backbone_verl.trainer import ResourcePoolManager, Role + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role # Map roles to their corresponding remote worker classes. role_worker_mapping = { @@ -237,26 +206,45 @@ def run(self, config): # Load the reward manager for training and validation. reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + config, + tokenizer, + num_examine=0, + **config.reward_model.get("reward_kwargs", {}), ) val_reward_fn = load_reward_manager( - config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + config, + tokenizer, + num_examine=1, + **config.reward_model.get("reward_kwargs", {}), ) resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) from verl.utils.dataset.rl_dataset import collate_fn - from astune.utils.process_dataset import create_rl_sampler # Create training and validation datasets. - from astune.task_reader.task_reader_base import TaskReaderRouter, task_to_standard_dataset - task_reader = TaskReaderRouter(config) + from ajet.task_reader import ( + RouterTaskReader, + task_to_standard_dataset, + ) + from ajet.utils.process_dataset import create_rl_sampler + + task_reader = RouterTaskReader( + config.ajet.task_reader.type, + config.ajet.task_reader, + ) val_dataset = task_to_standard_dataset(task_reader.get_validation_tasks()) train_dataset = task_to_standard_dataset(task_reader.get_training_tasks()) train_sampler = create_rl_sampler(config.data, train_dataset) - from astune.backbone_verl.trainer import BeyondAgentRayPPOTrainer + from ajet.backbone.trainer_verl import AjetRayPPOTrainer + + if config.ajet.enable_experimental_interchange_server: + from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server + + start_interchange_server(config) + # Initialize the PPO trainer. - trainer = BeyondAgentRayPPOTrainer( + trainer = AjetRayPPOTrainer( config=config, tokenizer=tokenizer, processor=processor, @@ -277,6 +265,5 @@ def run(self, config): trainer.fit() - if __name__ == "__main__": main() diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py new file mode 100644 index 00000000..b7ea7cf9 --- /dev/null +++ b/ajet/backbone/main_vllm.py @@ -0,0 +1,248 @@ +import atexit +import os +import sys +from types import SimpleNamespace + +import hydra +from openai import AsyncOpenAI, OpenAI + +from ajet.backbone.warm_up import warm_up_process +from ajet.task_rollout.native_parallel_worker import VerlRolloutManager +from ajet.utils.launch_utils import set_loguru_default_color +from ajet.schema.logprob import TokenAndProb +from ajet.utils.core_env_vars import get_runtime_env +from loguru import logger + +set_loguru_default_color() + + +class TokenAndProbVllmDebug(TokenAndProb): + def __init__(self, t): + # ChatCompletionTokenLogprob(token='token_id:73594', bytes=[96, 96, 96], logprob=-1.9073468138230965e-06, top_logprobs=[]) + token_id = int(t.token.split("token_id:")[-1]) + logprob = t.logprob + try: + decoded_string = bytes(t.bytes).decode("utf-8") + except Exception: + decoded_string = "" + str(t.bytes) + super().__init__(token_id=token_id, logprob=logprob, decoded_string=decoded_string) + + +class ChatCompletionScheduler: + def __init__(self, url, config): + from transformers import AutoTokenizer + + self.url = url + self.config = config + self.tokenizer = AutoTokenizer.from_pretrained(self.config.ajet.model.path) + self.chat_scheduler = SimpleNamespace( + model_name="dummy-model-name", + weighted_addresses="dummy-weighted-addresses", + completion_callback=SimpleNamespace(tokenizer=self.tokenizer), + ) + + def submit_chat_completions(self, messages, sampling_params, request_id, tools=[]): + client = OpenAI( + base_url=self.url, + api_key="token-abc123", + ) + sampling_params = dict( + n=1, + max_completion_tokens=self.config.ajet.rollout.max_response_length_in_one_turn, + ) + sampling_params["temperature"] = self.config.ajet.rollout.val_kwargs.temperature + sampling_params["top_k"] = self.config.ajet.rollout.val_kwargs.top_k + sampling_params["top_p"] = self.config.ajet.rollout.val_kwargs.top_p + + sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) + + if tools: + completion = client.chat.completions.create( + model=self.config.ajet.model.path, + messages=messages, + tools=tools, + extra_body=sampling_params, + ) + else: + completion = client.chat.completions.create( + model=self.config.ajet.model.path, + messages=messages, + extra_body=sampling_params, + ) + + message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + + # sometimes tool use message has no content field + if "content" not in message: + message["content"] = "" + + messages.append( + { + "role": message["role"], + "request_id": completion.id, + "content": message["content"], + "tool_calls": message.get("tool_calls", None), + "tokens": [TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content], # type: ignore + } + ) + return messages + + async def submit_chat_completions_async(self, messages, sampling_params, request_id, tools=[]): + client = AsyncOpenAI( + base_url=self.url, + api_key="token-abc123", + ) + sampling_params = dict( + n=1, + max_completion_tokens=self.config.ajet.rollout.max_response_length_in_one_turn, + ) + sampling_params["temperature"] = self.config.ajet.rollout.val_kwargs.temperature + sampling_params["top_k"] = self.config.ajet.rollout.val_kwargs.top_k + sampling_params["top_p"] = self.config.ajet.rollout.val_kwargs.top_p + + sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) + + if tools: + completion = await client.chat.completions.create( + model=self.config.ajet.model.path, + messages=messages, + tools=tools, + extra_body=sampling_params, + ) + else: + completion = await client.chat.completions.create( + model=self.config.ajet.model.path, + messages=messages, + extra_body=sampling_params, + ) + + message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + + # sometimes tool use message has no content field + if "content" not in message: + message["content"] = "" + + messages.append( + { + "role": message["role"], + "request_id": completion.id, + "content": message["content"], + "tool_calls": message.get("tool_calls", None), + "tokens": [TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content], # type: ignore + } + ) + return messages + + +def run(config): + from ajet.task_reader import RouterTaskReader + + # --------- fast adjustment for debugging --------- + warm_up_process(config) + max_parallel = config.ajet.debug.debug_max_parallel + n_task = config.ajet.debug.debug_first_n_tasks + vllm_port = config.ajet.debug.debug_vllm_port + + # --------- init --------- + async_rollout_manager = ChatCompletionScheduler(config=config, url=f"http://localhost:{vllm_port}/v1") + parallel_env = VerlRolloutManager( + config=config, + async_rollout_manager=async_rollout_manager, + max_parallel=max_parallel, + max_llm_retries=3, + llm_mode="remote", + tokenizer=async_rollout_manager.tokenizer, + ) + + task_reader = RouterTaskReader( + config.ajet.task_reader.type, + config.ajet.task_reader, + ) + tasks = task_reader.get_validation_tasks() + logger.info(tasks[:n_task]) + ctx_tracker = parallel_env.rollout(tasks=tasks[:n_task], mode="sample", epoch="1") # "sample" or "validate" + _ = parallel_env.to_dataproto(ctx_tracker) + + +@hydra.main( + config_path="ajet/default_config", + config_name="ajet_default", + version_base=None, +) +def main(config): + from omegaconf import OmegaConf + + OmegaConf.resolve(config) + runtime_env = get_runtime_env(config) + os.environ.update(runtime_env["env_vars"]) + # atexit.register(lambda: print("Process exiting, performing cleanup...")) + + if config.ajet.enable_experimental_interchange_server: + from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server + + start_interchange_server(config) + + def companion_launch(): + import torch + + from ajet.utils.smart_daemon import LaunchCommandWhenAbsent + + logger.info("Launching companion process for async LLM server...") + model_path = config.ajet.model.path + tensor_parallel_size = config.ajet.debug.debug_tensor_parallel_size + n_avail_gpus = torch.cuda.device_count() + if tensor_parallel_size > n_avail_gpus: + logger.info(f"Warning: tensor_parallel_size {tensor_parallel_size} is greater than available GPUs {n_avail_gpus}. Setting tensor_parallel_size to {n_avail_gpus}.") + tensor_parallel_size = n_avail_gpus + gpu_memory_utilization = config.actor_rollout_ref.rollout.gpu_memory_utilization + max_num_seqs = config.actor_rollout_ref.rollout.max_num_seqs + max_model_len = config.ajet.rollout.max_model_len + seed = config.ajet.debug.debug_vllm_seed + vllm_port = config.ajet.debug.debug_vllm_port + companion = LaunchCommandWhenAbsent( + full_argument_list=[ + sys.executable, + "-m", + "vllm.entrypoints.cli.main", + "serve", + f"{model_path}", + "--tensor-parallel-size", + f"{tensor_parallel_size}", + "--dtype", + "auto", + "--enforce-eager", + "--gpu-memory-utilization", + f"{gpu_memory_utilization}", + "--disable-custom-all-reduce", + "--max-num-seqs", + f"{max_num_seqs}", + "--max-model-len", + f"{max_model_len}", + "--load-format", + "auto", + "--enable-chunked-prefill", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enable-prefix-caching", + "--seed", + f"{seed}", + "--port", + f"{vllm_port}", + ], + dir="./", + tag="external_vllm_server", + ) + companion.launch( + launch_wait_time=1800, + success_std_string="Application startup complete", + env_dict={**os.environ}, + ) + + companion_launch() + + run(config) + + +if __name__ == "__main__": + main() diff --git a/ajet/backbone/trainer_trinity.py b/ajet/backbone/trainer_trinity.py new file mode 100644 index 00000000..fdbc6d9c --- /dev/null +++ b/ajet/backbone/trainer_trinity.py @@ -0,0 +1,369 @@ +import asyncio +import os +from typing import Dict, List, Literal, Optional, cast + +import datasets +import openai +import swanlab +from loguru import logger +from transformers import AutoTokenizer +from trinity.buffer.reader import READER +from trinity.buffer.reader.file_reader import TaskFileReader, _HFBatchReader +from trinity.buffer.schema import FORMATTER +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows import WORKFLOWS +from trinity.common.workflows.workflow import Task as TrinityTask +from trinity.common.workflows.workflow import Workflow +from trinity.utils.log import get_logger +from trinity.utils.monitor import MONITOR, Monitor + +from ajet.backbone.warm_up import warm_up_process +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.schema.trajectory import Sample +from ajet.task_reader import dict_to_ajet_task +from ajet.task_rollout.native_parallel_worker import DynamicRolloutManager +from ajet.utils.config_utils import read_ajet_config_with_cache +from ajet.utils.testing_utils import _test_if_test_mode + + +def get_ajet_config_from_trinity_side(): + yaml_path = os.environ.get("AJET_CONFIG_REDIRECT", None) + if yaml_path is None: + raise ValueError("AJET_CONFIG_REDIRECT is not set in environment variables") + ajet_config = read_ajet_config_with_cache(yaml_path) + return ajet_config + + +class TrinityRolloutManager(DynamicRolloutManager): + def __init__( + self, + is_eval, + task, + llm_handle, + tokenizer, + config, + llm_mode: Literal["local", "remote", "trinity"] = "trinity", + **kwargs, + ): + self.is_eval = is_eval + self.task = task + self.tokenizer = tokenizer + self.config = config + self.llm_mode = llm_mode + + super().__init__( + config=self.config, + async_rollout_manager=llm_handle, + max_parallel=1, + max_llm_retries=1, + tokenizer=tokenizer, + llm_mode=llm_mode, + **kwargs, + ) + + def convert_task(self, task: TrinityTask): + from ajet.schema.task import Task + + assert isinstance(task.raw_task, dict) + return dict_to_ajet_task(task.raw_task) + + def thread_worker(self): + observation_window = { + "stop": [False], + "step": [0], + "token": [0], + } + ajet_task = self.convert_task(self.task) + return self.rollout_env_worker( + task=ajet_task, + task_batch_index=0, + task_tag=f"T{ajet_task.task_id}#R", + mode="sample" if not self.is_eval else "validate", + task_thread_index=0, + observation_window=observation_window, + ) + + async def run_in_new_thread(self) -> MultiAgentContextTracker: + return cast( + MultiAgentContextTracker, + await asyncio.to_thread(self.thread_worker), + ) + + +@WORKFLOWS.register_module("ajet_workflow") +class AjetWorkflowWrap(Workflow): + is_async: bool = True + + def __init__( + self, + model: ModelWrapper, + task: TrinityTask, + auxiliary_models: Optional[List[openai.OpenAI]] = None, + ): + super().__init__( + task=task, + model=model, + auxiliary_models=auxiliary_models, + ) + self.task = task + self.model_client = model.get_openai_async_client() + self.is_eval = task.is_eval + # extract the query and the answer from the task + self.query = task.raw_task.get(task.format_args.prompt_key) # type: ignore [index] + self.answer = task.raw_task.get(task.format_args.response_key) # type: ignore [index] + + async def run_async(self): + ajet_config = get_ajet_config_from_trinity_side() + warm_up_process(ajet_config) + tracker = await TrinityRolloutManager( + is_eval=self.is_eval, + task=self.task, + llm_handle=self.model_client, + tokenizer=AutoTokenizer.from_pretrained(self.model_client.model_path), + config=ajet_config, + ).run_in_new_thread() + + sample_final = [] + try: + sample_arr = tracker.group_tokenize() + except Exception as e: + raise e + finally: + tracker.generate_log(global_step="NA") + sample_final += sample_arr + + exps = [] + for _, sample in enumerate(sample_final): + sample: Sample + input_ids = sample.input_ids + prompt_ids = sample.prompt_ids + response_ids = sample.response_ids + response_loss_mask = sample.response_loss_mask + + logprobs = sample.response_logprobs + reward = sample.step_reward # reward scalar + + metrics = { + "success_rate": tracker.reward_structure.success_rate, + "madness": tracker.reward_structure.madness, + } + + if len(response_ids) + len(prompt_ids) == len(input_ids) and len(logprobs) == len(response_ids) and len(logprobs) > 0: + exp = Experience( + tokens=input_ids, # [seq_length] prompt + response + prompt_length=len(prompt_ids), # Length of the prompt in tokens, used for generating attention masks + logprobs=logprobs, # [resp_length] + reward=reward, # + # advantages=None, + # returns=None, + info={}, + metrics=metrics, # for wandb logging (must be string:float) + response_text="", # optional + prompt_text="", # optional + #### for multi-turn experiences + action_mask=response_loss_mask, # 1 stands for training, 0 stands for ignoring + messages=sample.messages, # + # tools, + #### for dpo experiences + # chosen, + # rejected, + # chosen_messages, + # rejected_messages, + #### for multi-modal data + # multi_modal_inputs + ) + exps += [exp] + else: + logger.exception("Data length mismatch when converting sample to experience.") + return exps + + +try: + + @READER.register_module("ajet") + class AjetTaskReader(TaskFileReader): + def __init__(self, config): + self.config = config + self.read_batch_size = config.batch_size + self.split = config.split + + ajet_config = get_ajet_config_from_trinity_side() + + from ajet.task_reader import ( + RouterTaskReader, + task_to_standard_dataset, + ) + + task_reader = RouterTaskReader( + ajet_config.ajet.task_reader.type, + ajet_config.ajet.task_reader, + ) + + dataset_segments = [] + if "train" in self.split: + dataset_segments.append(task_to_standard_dataset(task_reader.get_training_tasks())) + if "val" in self.split: + dataset_segments.append(task_to_standard_dataset(task_reader.get_validation_tasks())) + if not dataset_segments: + raise ValueError(f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'.") + + concatenated_dataset = dataset_segments[0] if len(dataset_segments) == 1 else datasets.concatenate_datasets(dataset_segments) + + self.dataset = _HFBatchReader( + concatenated_dataset, + name=self.config.name, + default_batch_size=self.read_batch_size, + total_epochs=self.config.total_epochs if not self.config.is_eval else 1, + offset=self.config.index, + drop_last=not self.config.is_eval, + total_steps=self.config.total_steps, + enable_progress_bar=self.config.enable_progress_bar, + ) + self.formatter = FORMATTER.get("task")(self.config) + + def read(self, batch_size: Optional[int] = None) -> List: + batch_size = batch_size or self.read_batch_size + tasks = [] + samples, indices = self.dataset.read_batch(batch_size) + for sample in samples: + task = self.formatter.format(sample) + tasks.append(task) + return tasks + +except Exception: + pass + + +@MONITOR.register_module("swanlab") +class SwanlabMonitor(Monitor): + """Monitor with SwanLab. + + This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments. + + Supported monitor_args in config.monitor.monitor_args: + - api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env + (SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login. + - workspace (Optional[str]): Organization/username workspace. + - mode (Optional[str]): "cloud" | "local" | "offline" | "disabled". + - logdir (Optional[str]): Local log directory when in local/offline modes. + - experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}". + - description (Optional[str]): Experiment description. + - tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically. + - id (Optional[str]): Resume target run id (21 chars) when using resume modes. + - resume (Optional[Literal['must','allow','never']|bool]): Resume policy. + - reinit (Optional[bool]): Whether to re-init on repeated init() calls. + """ + + def __init__(self, project: str, group: str, name: str, role: str, config) -> None: + assert swanlab is not None, "swanlab is not installed. Please install it to use SwanlabMonitor." + + monitor_args = (config.monitor.monitor_args or {}) if config and getattr(config, "monitor", None) else {} + + # Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`. + api_key = os.environ.get("SWANLAB_API_KEY") + if api_key: + try: + swanlab.login(api_key=api_key, save=True) + except Exception: + # Best-effort login; continue to init which may still work if already logged in + pass + else: + raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.") + + # Compose tags (ensure list and include role/group markers) + tags = monitor_args.get("tags") or [] + if isinstance(tags, tuple): + tags = list(tags) + if role and role not in tags: + tags.append(role) + if group and group not in tags: + tags.append(group) + + # Determine experiment name + exp_name = monitor_args.get("experiment_name") or f"{name}_{role}" + self.exp_name = exp_name + ajet_config = get_ajet_config_from_trinity_side() + + # Prepare init kwargs, passing only non-None values to respect library defaults + init_kwargs = { + "project": project, + "workspace": monitor_args.get("workspace"), + "experiment_name": exp_name, + "description": monitor_args.get("description"), + "tags": tags or None, + "logdir": monitor_args.get("logdir"), + "mode": monitor_args.get("mode") or "cloud", + "settings": monitor_args.get("settings"), + "id": monitor_args.get("id"), + "config": ajet_config, + "resume": monitor_args.get("resume"), + "reinit": monitor_args.get("reinit"), + } + # Strip None values to avoid overriding swanlab defaults + init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None} + + self.logger = swanlab.init(**init_kwargs) + self.console_logger = get_logger(__name__, in_ray_actor=True) + + run_info = self.logger.public.json() + self.data_dashboard_url = run_info["cloud"]["experiment_url"] + + def log_table(self, table_name: str, experiences_table, step: int): + assert swanlab is not None, "swanlab is not installed. Please install it to use SwanlabMonitor." + + # Convert pandas DataFrame to SwanLab ECharts Table + headers: List[str] = list(experiences_table.columns) + # Ensure rows are native Python types + rows: List[List[object]] = experiences_table.astype(object).values.tolist() + try: + tbl = swanlab.echarts.Table() + tbl.add(headers, rows) + swanlab.log({table_name: tbl}, step=step) + except Exception: + # Fallback: log as CSV string if echarts table is unavailable + csv_str = experiences_table.to_csv(index=False) + swanlab.log({table_name: csv_str}, step=step) + + def log(self, data: dict, step: int, commit: bool = False) -> None: + """Log metrics.""" + # SwanLab doesn't use commit flag; keep signature for compatibility + assert swanlab is not None, "swanlab is not installed. Please install it to use SwanlabMonitor." + swanlab.log(data, step=step) + self.console_logger.info(f"Step {step}: {data}") + + ajet_config = get_ajet_config_from_trinity_side() + experiment_dir = ajet_config.ajet.experiment_dir + trinity_log = f"{experiment_dir}/{self.exp_name}.log" + + with open(trinity_log, "a") as f: + f.write(f"Step {step}: {data}\n") + + if ajet_config.ajet.execute_test: # apply a test probe + if "critic/score/mean" in data: + return + if "experience_pipeline/group_advantages/reward_mean/mean" not in data: + return + test_robot_data = {} + test_robot_data["step"] = step + test_robot_data["data_dashboard_url"] = self.data_dashboard_url + test_robot_data["reward_for_test_robot"] = data["experience_pipeline/group_advantages/reward_mean/mean"] + _test_if_test_mode(key="reward_probe", value=test_robot_data, config=ajet_config) + + def close(self) -> None: + try: + # Prefer run.finish() if available + if hasattr(self, "logger") and hasattr(self.logger, "finish"): + self.logger.finish() + else: + # Fallback to global finish + swanlab.finish() + except Exception: + pass + + @classmethod + def default_args(cls) -> Dict: + """Return default arguments for the monitor.""" + return {} diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py new file mode 100644 index 00000000..7c23f58d --- /dev/null +++ b/ajet/backbone/trainer_verl.py @@ -0,0 +1,972 @@ +# Copyright 2025 Alibaba Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import uuid +from collections import defaultdict +from pprint import pprint +from typing import List, Optional + +import hydra +import numpy as np +import torch +from beast_logger import print_dict +from loguru import logger +from omegaconf import OmegaConf +from tqdm import tqdm +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.single_controller.ray import RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.config import AlgoConfig +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, +) +from verl.trainer.ppo.ray_trainer import ( + RayPPOTrainer, + Role, + apply_kl_penalty, + compute_response_mask, +) +from verl.trainer.ppo.reward import compute_reward +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.metric import reduce_metrics + +from ajet.backbone.warm_up import warm_up_process +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import Task +from ajet.task_reader import dict_to_ajet_task +from ajet.task_rollout.native_parallel_worker import VerlRolloutManager +from ajet.utils.metric_helper import save_trajectory_as_json_file, update_metrics + + +def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor: + """ + Compute reward for a batch of data. + Args: + data: DataProto object containing the input data. + return_dict: Whether to return a dictionary or just the reward tensor. + + Returns: + Tensor of shape (bs, response_len) if return_dict is False, + or a dict with 'reward_tensor' and 'reward_extra_info'. + """ + # Within DataFlow, world.execute() will pass a float score, which will be contained in the DataProto.non_tensor_batch('reward_scores') + + # Initialize reward tensor + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) # (bs, reslen) + reward_extra_info = defaultdict(list) + + # Batch-level processing + prompt_ids_batch = data.batch["prompts"] # (bs, prompt_len) + prompt_lengths = prompt_ids_batch.shape[-1] + + # Get attention masks for all items + attention_masks = data.batch["attention_mask"] # (bs, total_len) + response_lengths = attention_masks[:, prompt_lengths:].sum(dim=1) # (bs, ) + + # Get reward scores + reward_scores_list = [item for item in data.non_tensor_batch["reward_scores"]] + reward_scores = torch.tensor(reward_scores_list, device=reward_tensor.device, dtype=torch.float32) # (bs, ) + + # Use advanced indexing to assign rewards (placing reward at the last token position) + reward_tensor[torch.arange(len(data)), response_lengths - 1] = reward_scores + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor + + +def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto): + """ + Union the gen_batch_output with the batch based on task_id. + """ + map_task_id_to_index = {t.task_id: i for i, t in enumerate(tasks)} + gen_task_task_ids = gen_batch_output.non_tensor_batch["task_ids"] + indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids] + batch_extend = batch.select_idxs(indices) + batch_final = batch_extend.union(gen_batch_output) + return batch_final + + +def compute_advantage( + data: DataProto, + adv_estimator: AdvantageEstimator, + gamma: float = 1.0, + lam: float = 1.0, + num_repeat: int = 1, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> DataProto: + """Compute advantage estimates for policy optimization. + + This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. + The advantage estimates are used to guide policy optimization in RL algorithms. + + Args: + data (DataProto): The data containing batched model outputs and inputs. + adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). + gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. + lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. + num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. + norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in + GRPO. Defaults to True. + config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. + + Returns: + DataProto: The updated data with computed advantages and returns. + """ + # Back-compatible with trainers that do not compute response mask in fit + if "response_mask" not in data.batch.keys(): + data.batch["response_mask"] = compute_response_mask(data) + # prepare response group + if adv_estimator == AdvantageEstimator.GAE: + # Compute advantages and returns using Generalized Advantage Estimation (GAE) + advantages, returns = core_algos.compute_gae_advantage_return( + token_level_rewards=data.batch["token_level_rewards"], + values=data.batch["values"], + response_mask=data.batch["response_mask"], + gamma=gamma, + lam=lam, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + if config.get("use_pf_ppo", False): + data = core_algos.compute_pf_ppo_reweight_data( + data, + config.pf_ppo.get("reweight_method"), + config.pf_ppo.get("weight_pow"), + ) + elif adv_estimator == AdvantageEstimator.GRPO: + # Initialize the mask for GRPO calculation + grpo_calculation_mask = data.batch["response_mask"] + # If multi-turn, replace the mask with the relevant part of loss_mask + # Get length from the initial response mask + response_length = grpo_calculation_mask.size(1) + # This mask is the one intended for GRPO + grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] + # Call compute_grpo_outcome_advantage with parameters matching its definition + advantages, returns = core_algos.compute_grpo_outcome_advantage( + token_level_rewards=data.batch["token_level_rewards"], + response_mask=grpo_calculation_mask, + index=data.non_tensor_batch["uid"], + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + else: + # handle all other adv estimator type other than GAE and GRPO + adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) + adv_kwargs = { + "token_level_rewards": data.batch["token_level_rewards"], + "response_mask": data.batch["response_mask"], + "config": config, + } + if "uid" in data.non_tensor_batch: # optional + adv_kwargs["index"] = data.non_tensor_batch["uid"] + if "reward_baselines" in data.batch: # optional + adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] + + # calculate advantage estimator + advantages, returns = adv_estimator_fn(**adv_kwargs) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data + + +class AjetRayPPOTrainer(RayPPOTrainer): + """Distributed PPO trainer using Ray for scalable reinforcement learning. + Slightly modified from RayPPOTrainer in verl. + """ + + # ####################################### + # init + # ####################################### + def _validate_config(self): + config = self.config + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + minimal_bsz = n_gpus + + # 1. Check total batch size for data correctness + real_train_batch_size = config.ajet.data.train_batch_size * config.ajet.rollout.num_repeat + assert real_train_batch_size % minimal_bsz == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " f"({minimal_bsz})" + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + """Validate mutually exclusive micro batch size configuration options. + + Ensures that users don't set both deprecated micro_batch_size and + the new micro_batch_size_per_gpu parameters simultaneously. + + Args: + mbs: Deprecated micro batch size parameter value. + mbs_per_gpu: New micro batch size per GPU parameter value. + name (str): Configuration section name for error messages. + + Raises: + ValueError: If both parameters are set or neither is set. + """ + settings = { + "reward_model": "micro_batch_size", + "actor_rollout_ref.ref": "log_prob_micro_batch_size", + "actor_rollout_ref.rollout": "log_prob_micro_batch_size", + } + + if name in settings: + param = settings[name] + param_per_gpu = f"{param}_per_gpu" + + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated).") + + # Actor validation done in ActorConfig.__post_init__ and validate() + try: + actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor) + actor_config.validate( + n_gpus, + config.ajet.data.train_batch_size, + config.actor_rollout_ref.model, + ) + except hydra.errors.InstantiationException: + raise ValueError("You are using an unsupported VERL version. Please read `documents/backbones.md`") + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + if self.use_reference_policy: + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) + + # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.ajet.rollout.log_prob_micro_batch_size, + config.ajet.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) + + # Check for reward model micro-batch size conflicts + if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: + check_mutually_exclusive( + config.reward_model.micro_batch_size, + config.reward_model.micro_batch_size_per_gpu, + "reward_model", + ) + + if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + logger.warning("NOTICE: You have both enabled in-reward kl and kl loss.") + + # critic + if self.use_critic: + critic_config = omega_conf_to_dataclass(config.critic) + critic_config.validate(n_gpus, config.ajet.data.train_batch_size) + + if config.data.get("val_batch_size", None) is not None: + logger.warning("WARNING: val_batch_size is deprecated." + " Validation datasets are sent to inference engines as a whole batch," + " which will schedule the memory themselves.") + + # check eval config + if config.ajet.rollout.val_kwargs.do_sample: + assert config.ajet.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample" + + logger.success("[validate_config] All configuration checks passed successfully!") + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role="ref", + profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RewardModel], + config=self.config.reward_model, + ) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.trainer, "profile_steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") + assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, "worker_nsight_options must be set when profile_steps is set" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(OmegaConf.select(self.config.trainer, "worker_nsight_options")) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + from verl.experimental.agent_loop.agent_loop import ( + AgentLoopManager, + AsyncLLMServerManager, + ) + + self.async_rollout_mode = True + agent_loop_manager = AgentLoopManager( + config=self.config, + worker_group=self.actor_rollout_wg, + ) + self.async_server_list = agent_loop_manager.async_llm_servers + self.async_rollout_manager = AsyncLLMServerManager(self.config, self.async_server_list) + + self.reward_fn = parse_reward_from_dataproto + self.val_reward_fn = parse_reward_from_dataproto + + self.parallel_env = VerlRolloutManager( + config=self.config, + async_rollout_manager=self.async_rollout_manager, + max_parallel=self.config.ajet.rollout.max_env_worker, + tokenizer=self.tokenizer, + ) + + # ####################################### + # training loop + # ####################################### + def fit(self): # noqa: C901 + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + from verl.utils.tracking import Tracking + + warm_up_process(self.config) + self.verl_logger = Tracking( + project_name=self.config.ajet.project_name, + experiment_name=self.config.ajet.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # wake and sleep to enforce param sync + self.async_rollout_manager.wake_up() + self.async_rollout_manager.sleep() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + self.verl_logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Training Progress", + ) + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False + next_step_profile = False + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling(not prev_step_profile and curr_step_profile if self.config.trainer.profile_continuous_steps else curr_step_profile) + + batch_dict["index"] = torch.tensor( + [i for i in range(len(batch_dict["task_id"]))], + dtype=torch.long, + ) + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object, + ) + + # # pop those keys for generation + batch_keys_to_pop = ["index"] + non_tensor_batch_keys_to_pop = [ + "task_id", + "main_query", + "env_type", + "metadata", + "init_messages", + ] + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + logger.info("=== + rollout step begin ===") + with marked_timer("gen", timing_raw, color="red"): + assert self.async_rollout_mode + logger.info("=== wake up begin ===") + self.async_rollout_manager.wake_up() + logger.info("=== wake up end ===") + tasks: List[Task] = [ + dict_to_ajet_task( + dict( + task_id=gen_batch.non_tensor_batch["task_id"][i], + main_query=gen_batch.non_tensor_batch["main_query"][i], + env_type=gen_batch.non_tensor_batch["env_type"][i], + metadata=gen_batch.non_tensor_batch["metadata"][i], + init_messages=gen_batch.non_tensor_batch["init_messages"][i], + ) + ) + for i in range(len(gen_batch)) + ] + logger.info(str([gen_batch.non_tensor_batch["task_id"][i] for i in range(len(gen_batch))])) + logger.info("=" * 10 + "start fit rollout" + "=" * 10) + self.parallel_env.current_global_steps = self.global_steps + context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout(tasks, mode="sample", epoch=f"train.{epoch}") + logger.info("=" * 10 + "end fit rollout" + "=" * 10) + logger.info("begin to convert context_tracker_arr to dataproto") + gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) + logger.info("end convertion") + + success_rate = [traj.reward_structure.success_rate for traj in context_tracker_arr] + madness_rate = [traj.reward_structure.madness for traj in context_tracker_arr] + # reward = [traj.reward_structure.raw_reward for traj in context_tracker_arr] + round_cnt = [traj.round_cnt for traj in context_tracker_arr] + metrics.update( + { + "critic/round_cnt": np.mean(round_cnt), + "critic/madness_rate": np.mean(madness_rate), + "critic/success_rate": np.mean(success_rate), + "critic/real_success_rate": np.mean(context_tracker_arr[0].current_batch_success_rate), + "critic/real_reward": np.mean(context_tracker_arr[0].current_batch_reward), + } + ) + save_trajectory_as_json_file(context_tracker_arr, self.global_steps, self.config, prefix="train") + update_metrics(context_tracker_arr, metrics) + if self.config.ajet.execute_test: # apply a test probe + from swanlab.data.run.main import get_run + + from ajet.utils.testing_utils import ( + _test_if_test_mode, + ) + + run_info = get_run().public.json() # type: ignore + data = { + "step": self.global_steps, + "reward_for_test_robot": metrics["critic/real_reward"], + "data_dashboard_url": run_info["cloud"]["experiment_url"], + } + _test_if_test_mode(key="reward_probe", value=data, config=self.config) + + logger.info(f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}") + self.async_rollout_manager.sleep() + logger.info("=== - rollout step end ===") + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + raise NotImplementedError("REMAX is not supported in GRPO yet.") + + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object, + ) + batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output) + batch.batch["response_mask"] = compute_response_mask(batch) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + raise NotImplementedError("launch_reward_fn_async is not supported in GRPO yet.") + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + logger.info("=== + compute log_probs begin ===") + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_loss = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=loss_agg_mode, + ) + assert not torch.isnan(entropy_loss).item(), "Entropy loss should not be NaN, something must have gone terribly wrong." + old_log_prob_metrics = {"actor/entropy": entropy_loss.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, + kl_ctrl=self.kl_ctrl_in_reward, + kl_penalty=self.config.algorithm.kl_penalty, + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.ajet.rollout.num_repeat, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = True + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # validate + if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and (is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration): + if esi_close_to_expiration: + logger.info("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = self.global_steps + 1 in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False + self._stop_profiling(curr_step_profile and not next_step_profile if self.config.trainer.profile_continuous_steps else curr_step_profile) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + self.verl_logger.log(data=metrics, step=self.global_steps) + progress_bar.update(1) + self.global_steps += 1 + + # # when enabled oai request interchange, we need to clear the cache from time to time + # if self.config.ajet.enable_experimental_interchange_server: + # from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear + # ensure_dat_interchange_server_cache_clear() + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # ####################################### + # Validate + # ####################################### + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_outputs = [] + sample_scores = [] + sample_turns = [] + + for test_data in self.val_dataloader: + test_data["index"] = torch.tensor([i for i in range(len(test_data["task_id"]))], dtype=torch.long) + test_batch = DataProto.from_single_dict(test_data) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.ajet.rollout.val_kwargs.num_repeat, + interleave=True, + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + batch_keys_to_pop = ["index"] + non_tensor_batch_keys_to_pop = [ + "task_id", + "main_query", + "env_type", + "metadata", + "init_messages", + ] + if "multi_modal_data" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + if "agent_name" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("agent_name") + if "extras" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("extras") + + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.ajet.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + logger.info(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + self.async_rollout_manager.wake_up() + main_val_dataset = self.get_eval_dataset() + + logger.info("=" * 10 + "start validate rollout" + "=" * 10) + context_tracker_arr, tasks, val_metrics = self.eval_dataset( + target_dataset=main_val_dataset, + target_dataset_name="main_val_dataset", + mode="validate", + epoch="test.1", + ) + logger.info("=" * 10 + "end validate rollout" + "=" * 10) + test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr) + self.async_rollout_manager.sleep() + logger.info("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], + dtype=object, + ) + tasks = tasks[: len(main_val_dataset)] + test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch) + # test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # evaluate using reward_function + if self.val_reward_fn is None: + raise ValueError("val_reward_fn must be provided for validation.") + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + logger.info(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + logger.info(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + break # hack to escape the loop after one batch + + metric_dict = val_metrics + + return metric_dict + + def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): + """ + Evaluate a dataset by running rollouts and computing task completion metrics. + + Args: + target_dataset: The dataset to evaluate + target_dataset_name: Name for logging purposes + mode: Evaluation mode ("sample" or "validate") + epoch: Current epoch for logging + + Returns: + Tuple of (ctx_trackers, tasks) containing trajectory results and task definitions + """ + pass_n = self.config.ajet.trainer_common.val_pass_n + + tasks = [] + for _ in range(pass_n): + tasks += [task for task in target_dataset] + + ctx_trackers = self.parallel_env.rollout(tasks=tasks, mode=mode, epoch=epoch) # "sample" or "validate" + task_results = {} + for ctx_tracker in ctx_trackers: + reward = ctx_tracker.reward_structure.raw_reward + task_id = ctx_tracker.task_id + if task_id not in task_results: + task_results[task_id] = {} + task_results[task_id]["reward_arr"] = [] + task_results[task_id]["tag_arr"] = [] + if reward >= 1: + ctx_tracker.tag = "success" + elif reward == 0: + ctx_tracker.tag = "failure" + else: + ctx_tracker.tag = "half_success" + task_results[task_id]["tag_arr"] += [ctx_tracker.tag] + task_results[task_id]["reward_arr"] += [ctx_tracker.reward_structure.raw_reward] + task_results[task_id]["scenario"] = task_id.split("_")[0] + + repeated_success_tasks = 0 + num_all_success_tasks = 0 # number of tasks that is successful among all n attempts + num_pass_n_tasks = 0 # number of tasks that is successful at least once among n attempts + for task_id, task_outcomes in task_results.items(): + # Calculate num_all_success_tasks # The number of tasks where all were successful in n experiments + # Calculate num_pass_n_tasks # The number of tasks where at least one was successful in n experiments + assert len(task_outcomes["tag_arr"]) == pass_n + if all(tag == "success" for tag in task_outcomes["tag_arr"]): + num_all_success_tasks += 1 + if any(tag == "success" for tag in task_outcomes["tag_arr"]): + num_pass_n_tasks += 1 + repeated_success_tasks += task_outcomes["tag_arr"].count("success") + + # record logs + for ctx_tracker in ctx_trackers: + ctx_tracker.generate_log() + + rewards = [ctx_tracker.reward_structure.raw_reward for ctx_tracker in ctx_trackers] + num_tasks = len(task_results) + assert num_tasks == len(ctx_trackers) // pass_n + + val_metrics = { + "target dataset name": target_dataset_name, + "pass_n": pass_n, + "total_tasks": len(task_results), + "num_all_success_tasks": num_all_success_tasks, + f"num_pass_n_tasks(pass@{pass_n})": num_pass_n_tasks, + "TGC@1": repeated_success_tasks / (num_tasks * pass_n), + f"TGC@{pass_n}": num_pass_n_tasks / num_tasks, + f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks, + "mean_reward": sum(rewards) / len(rewards) if rewards else 0, + } + save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval") + update_metrics(ctx_trackers, val_metrics) + print_dict( + val_metrics, + narrow=True, + header=target_dataset_name, + mod="evaluation", + ) + + self.verl_logger.log(data=val_metrics, step=self.global_steps) + + return ctx_trackers, tasks, val_metrics + + def get_eval_dataset(self): + from ajet.task_reader import RouterTaskReader + + task_reader = RouterTaskReader( + self.config.ajet.task_reader.type, + self.config.ajet.task_reader, + ) + tasks = task_reader.get_validation_tasks() + self.main_val_dataset = tasks + return self.main_val_dataset diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py new file mode 100644 index 00000000..e7db56c1 --- /dev/null +++ b/ajet/backbone/warm_up.py @@ -0,0 +1,105 @@ +""" +Process level warm up +""" + + +import asyncio +import logging +import os +from ajet.utils.async_utils import apply_httpx_aclose_patch + +apply_httpx_aclose_patch() + + +def init_parallel_rollout_logger(experiment_name): + """Initialize the logger with the given configuration.""" + if "PROCESS_LEVEL_WARMUP_INIT_LOGGER" in os.environ: + return + os.environ["PROCESS_LEVEL_WARMUP_INIT_LOGGER"] = "1" + + from datetime import datetime + + from beast_logger import register_logger + + final_log_path = os.path.join( + "saved_experiments", + experiment_name, + datetime.now().strftime("%Y_%m_%d_%H_%M"), + # machine host name + os.uname().nodename, + ) + os.environ["BEST_LOGGER_PATH"] = final_log_path + non_console_mods = ["rollout", "token_clip", "bad_case"] + register_logger( + mods=["evaluation", "exception", "benchmark"], + non_console_mods=non_console_mods, + auto_clean_mods=[], + base_log_path=final_log_path, + debug=False, + ) + + target_logger = logging.getLogger("vllm.entrypoints.openai.tool_parsers.hermes_tool_parser") + target_logger.setLevel(logging.CRITICAL) + logging.getLogger("httpx").setLevel(logging.WARNING) + + +def warm_up_task_judge_when_needed(config): + if config.ajet.task_judge.judge_type == "rubrics_auto_grader": + from ajet.task_judge.rm_auto_grader_judge import AutoGraderJudge + + judge = AutoGraderJudge(config) + asyncio.run(judge.generate_rubrics_from_samples()) + asyncio.run(judge.load_rubrics_from_cache()) + + +def clean_up_tmp_ajet_dir(config): + """Clean up old IPC socket files in /tmp/ajet directory.""" + import time + + if config.ajet.enable_experimental_interchange_server is False: + return + + tmp_dir = "/tmp/ajet" + if not os.path.exists(tmp_dir): + return + current_time = time.time() + ttl = 4 * 3600 + try: + for filename in os.listdir(tmp_dir): + if not filename.endswith(".sock"): + continue + + file_path = os.path.join(tmp_dir, filename) + try: + print(current_time - os.path.getmtime(file_path)) + if current_time - os.path.getmtime(file_path) > ttl: + os.remove(file_path) + except OSError: + pass + except OSError: + pass + + +def warm_up_process(config): + """ + Process level warm up + This will not be called multiple when: + - multi-threading + - forked multi-processing + This may be called multiple times when: + - spawned multi-processing + - ray remote actor + + --- + + Note: Skipping process level warm up will not cause significant issues, but may lead to + slightly longer initialization times for certain components in each process. + """ + + if "PROCESS_LEVEL_WARMUP_INIT" in os.environ: + return + os.environ["PROCESS_LEVEL_WARMUP_INIT"] = "1" + experiment_name = config.ajet.experiment_name + init_parallel_rollout_logger(experiment_name) + warm_up_task_judge_when_needed(config) + clean_up_tmp_ajet_dir(config) diff --git a/astune/__init__.py b/ajet/context_tracker/__init__.py similarity index 100% rename from astune/__init__.py rename to ajet/context_tracker/__init__.py diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py new file mode 100644 index 00000000..7bd50dc5 --- /dev/null +++ b/ajet/context_tracker/base_tracker.py @@ -0,0 +1,143 @@ +from typing import List, Tuple, Union +from typing import List, Union, Tuple, Dict, Optional, Any +from ajet.schema.task import WorkflowTask + +from ajet.schema.extended_msg import ( + INVALID_LOG_PROB_VALUE, + ExtendedMessage, + find_sublist_indices, +) +from ajet.schema.trajectory import Reward + + +def replace_token_ids( + token_container, + precise_token, + precise_logprob, + begin_ids, + end_ids, +) -> Tuple[List[int], List[int], List[int], bool]: + """ + Replace token ids + + input || token_container: [begin_ids, ... ids_that_may_not_precise ... , end_ids, other_ids] + ==> + output1 || final_token_ids: [begin_ids, ... precise_token ... , end_ids, other_ids] + output2 || final_logprob: [NA, ... precise_logprob ... , NA , NA ] + output3 || loss_mask: [0, ... 1 ... , 1 , 0 ] + output4 || lack_normal_eos: False + + + + test case: + ----------- case 1 (with_normal_eos) ----------- + (NA = INVALID_LOG_PROB_VALUE) + + begin_ids = [151644, 77091, 198] + end_ids = [151645] + token_container = [151644, 77091, 198, 1, 1, 1, 151645, 1, 2, 3, 4] + precise_token = [4, 3, 2, 151645] + precise_logprob = [-0.4, -0.3, -0.2, -0.1] + + assert replace_token_ids( + token_container, + precise_token, + precise_logprob, + begin_ids, + end_ids, + ) = ( + [151644, 77091, 198, 4, 3, 2, 151645, 1, 2, 3, 4] + [NA, NA, NA, -0.4, -0.3, -0.2, -0.1, NA, NA, NA, NA] + [0, 0, 0, 1, 1, 1, 1, 0 ,0 ,0 ,0] + ) + + ----------- case 2 (lack_normal_eos) ----------- + begin_ids = [151644, 77091, 198] + end_ids = [151645] + token_container = [151644, 77091, 198, 1, 1, 1, 151645, 1, 2, 3, 4] + precise_token = [3, 2, 1,] + precise_logprob = [-0.3, -0.2, -0.1] + + assert replace_token_ids( + token_container, + precise_token, + precise_logprob, + begin_ids, + end_ids, + ) = ( + [151644, 77091, 198, 3, 2, 1, 151645, 1, 2, 3, 4] + [NA, NA, NA, -0.3, -0.2, -0.1, NA, NA, NA, NA] + [0, 0, 0, 1, 1, 1, 0, 0 ,0 ,0 ,0] + ) + + """ + + _begin_index = find_sublist_indices(token_container, begin_ids) + len(begin_ids) + _end_index = find_sublist_indices(token_container, end_ids, reverse=True) + + if precise_token[-len(end_ids) :] == end_ids: # remove end_ids token + lack_normal_eos = False + precise_token_center = precise_token[: -len(end_ids)] + precise_logprob_center = precise_logprob[: -len(end_ids)] + logprob_eos_tail = precise_logprob[-len(end_ids) :] + else: + lack_normal_eos = True + precise_token_center = precise_token + precise_logprob_center = precise_logprob + logprob_eos_tail = [] + + if precise_token[: len(begin_ids)] == begin_ids: # remove begin_ids token + # precise_token = precise_token[len(begin_ids) :] + # precise_logprob_center = precise_logprob[len(begin_ids) :] + raise ValueError("Unexpected situation, wrong llm output (unexpected BOS): please post an github issue.") + + final_token_ids = token_container[:_begin_index] + precise_token_center + token_container[_end_index:] + final_logprob = [INVALID_LOG_PROB_VALUE] * _begin_index + precise_logprob_center + logprob_eos_tail + [INVALID_LOG_PROB_VALUE] * (len(token_container) - _end_index - len(logprob_eos_tail)) + loss_mask = [0] * _begin_index + [1] * len(precise_logprob_center) + [1] * len(logprob_eos_tail) + [0] * (len(token_container) - _end_index - len(logprob_eos_tail)) + return final_token_ids, final_logprob, loss_mask, lack_normal_eos + + +class BaseTracker(object): + def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): + self.workflow_task = workflow_task + self.task_batch_index = self.workflow_task.task_batch_index + self.task_tag = self.workflow_task.task_tag + self.task_id = self.workflow_task.task_id + self.episode_uuid = self.workflow_task.episode_uuid + + self.config = config + self.tokenizer = tokenizer + self.saved_timelines: List[List[ExtendedMessage]] = [] + self.current_context_status = "" + max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn + max_model_len: int = self.config.ajet.rollout.max_model_len + self.max_seq_length: int = max_model_len - max_response_length + self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") + self._im_start_token_id = tokenizer.encode("<|im_start|>")[0] + self.generated_token_cnt = 0 + self.terminal_rewards_dict = {} + self.discarded = False + self.is_terminated = False + self.reward_structure: Union[Reward, None] = None + self.context_time_cost = 0 + self.tag = "" + self.current_batch_success_rate: float = float("-inf") + self.current_batch_reward: float = float("-inf") + self.already_mad_flag: bool = False + self.round_cnt = 0 + self.generation_prompt_token = None + self.log_metrics: Optional[Dict[str, Union[float, List[float]]]] = None # Initialize workflow_metadata to store tool statistics + + assert self.config.ajet.data.max_prompt_length + self.config.ajet.data.max_response_length <= max_model_len + + def group_tokenize(self): + raise NotImplementedError + + def group_tokenize_multi_group(self): + raise NotImplementedError + + def group_tokenize_single_group(self, timeline): + raise NotImplementedError + + def tokenize_steps(self, ext_steps: List[ExtendedMessage], index: int, total_steps: int) -> dict: + raise NotImplementedError diff --git a/ajet/context_tracker/basic_tracker.py b/ajet/context_tracker/basic_tracker.py new file mode 100644 index 00000000..8a759c92 --- /dev/null +++ b/ajet/context_tracker/basic_tracker.py @@ -0,0 +1,408 @@ +import torch +import copy +from collections import defaultdict +from typing import List, Tuple +from loguru import logger + +from ajet.context_tracker.base_tracker import ( + BaseTracker, + ExtendedMessage, + replace_token_ids, +) +from ajet.schema.trajectory import Reward, Sample +from ajet.utils.tokenizer import ajet_apply_chat_template + + +class BaseContextTracker(BaseTracker): + """ + A linear context tracker template that handles the conversation flow between LLM and environment. + This class manages the context window, tokenization, and message history in a linear fashion. + + Attributes: + config: Configuration object containing environment and model settings + tokenizer: Tokenizer instance for processing text + full_context (List[ExtendedMessage]): List of all messages in the conversation + current_context_status (str): Current status of the context + max_seq_length (int): Maximum sequence length for the context window + terminal_rewards_dict (dict): Dictionary storing terminal rewards + """ + + def __init__(self, config, tokenizer, **kwargs): + super().__init__(config, tokenizer, **kwargs) + self.generation_prompt_token = self.get_generation_prompt_token() + + def remove_last_non_llm_msg(self, ext_msg_list: List[ExtendedMessage]): + if len(ext_msg_list) > 0: + if ext_msg_list[-1].author != "llm": + ext_msg_list.pop(-1) + return ext_msg_list + + def get_inc(self, text_frag_from, text_frag_to): + """ + Get the incremental token array from text_frag_from to text_frag_to. + """ + tokenizer_output = self.tokenizer(text_frag_from, return_tensors="pt", padding=False) + tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() # type: ignore + token_ids_acc = tokenizer_input_ids + + tokenizer_output = self.tokenizer(text_frag_to, return_tensors="pt", padding=False) + input_ids = tokenizer_output["input_ids"][0].tolist() # type: ignore + input_id_increment = input_ids[len(token_ids_acc) :] # get the new tokens added in this step + overlap_length = 0 + for i in range(len(token_ids_acc)): + if i < len(token_ids_acc) and input_ids[i] == token_ids_acc[i]: + overlap_length += 1 + else: + break + msg = f"previous token length: {len(token_ids_acc)}, overlap token length: {(overlap_length)}, increment token length: {len(input_id_increment)}" + return input_id_increment, msg + + # generate token + def get_token_inc_from_llm_response(self, input_msg_ref, llm_output, tools: List[dict] = []) -> Tuple[List[int], List[int], List[int], bool]: + llm_output_role_content = { + "role": llm_output["role"], + "content": llm_output["content"], + } + if llm_output.get("tool_calls", None): + llm_output_role_content.update({"tool_calls": llm_output.get("tool_calls", [])}) + + # completion_token_arr will contain generation_prompt header + completion_token_arr, _ = self.get_inc( + ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=input_msg_ref, + tokenize=False, + tools=tools, + add_generation_prompt=False, + ), + ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=input_msg_ref + [llm_output_role_content], + tokenize=False, + tools=tools, + add_generation_prompt=False, + ), + ) + vllm_output_raw_token = [t.token_id for t in llm_output["tokens"]] + vllm_output_raw_logprob = [t.logprob for t in llm_output["tokens"]] + self.generated_token_cnt += len(vllm_output_raw_token) + if not self.generation_prompt_token: + self.generation_prompt_token = self.get_generation_prompt_token() + final_token_arr, token_logprob_arr, loss_mask, lack_normal_eos = replace_token_ids( + token_container=completion_token_arr, + precise_token=vllm_output_raw_token, + precise_logprob=vllm_output_raw_logprob, + begin_ids=self.generation_prompt_token, + end_ids=[self.tokenizer.eos_token_id], + ) + return final_token_arr, token_logprob_arr, loss_mask, lack_normal_eos + + def filter_context_via_author(self, timeline, author: str) -> List[ExtendedMessage]: + return copy.deepcopy([c for c in timeline if c.author == author]) + + def filter_context_via_authors(self, timeline, authors: List[str]) -> List[ExtendedMessage]: + return copy.deepcopy([c for c in timeline if c.author in authors]) + + def filter_context_via_authors_with_limit(self, timeline, authors: List[str], limit: dict) -> List[ExtendedMessage]: + """ + limit = { + "llm": "keep_last@2" + "env": "keep_first@2" + } + """ + filtered_via_authors = copy.deepcopy([c for c in timeline if c.author in authors]) + for limit_author, limit_item in limit.items(): + limit_item_command, limit_item_value = limit_item.split("@") + if limit_item_command == "keep_last": + limit_item_value = int(limit_item_value) + # remove all message whose author is `llm_author` except the last `limit_item_value` messages + num_need_rm = len([c for c in filtered_via_authors if c.author == limit_author]) - limit_item_value + if num_need_rm > 0: + num_already_rm = 0 + filtered_via_authors_new = [] + for c in filtered_via_authors: + if c.author == limit_author: + num_already_rm += 1 + if num_already_rm <= num_need_rm: + continue + filtered_via_authors_new += [c] + filtered_via_authors = filtered_via_authors_new + + elif limit_item_command == "keep_first": + limit_item_value = int(limit_item_value) + # remove all message whose author is `llm_author` except the first `limit_item_value` messages + num_need_keep = len([c for c in filtered_via_authors if c.author == limit_author]) - limit_item_value + if num_need_keep > 0: + num_already_keep = 0 + filtered_via_authors_new = [] + for c in filtered_via_authors: + if c.author == limit_author: + num_already_keep += 1 + if num_already_keep > limit_item_value: + continue + filtered_via_authors_new += [c] + filtered_via_authors = filtered_via_authors_new + + else: + raise ValueError(f"Unknown limit_item_command {limit_item_command} in filter_context_via_authors_with_limit") + return filtered_via_authors + + def compute_step_level_reward(self, index: int, total_steps: int) -> float: + # TODO: support multi-step reward + assert self.reward_structure is not None + + # --------------- global level reward --------------- + global_reward = self.reward_structure.raw_reward + gamma = self.config.ajet.rollout.gamma + step_reward_base = global_reward * (gamma ** (total_steps - index - 1)) + assert gamma == 1.0, "Currently only support gamma == 1.0, we'll support multi-step reward in the future" + + # --------------- compute step level reward --------------- + step_reward = step_reward_base # reward scalar + if self.already_mad_flag: + step_reward = self.config.ajet.rollout.agent_madness_reward + self.reward_structure.madness = -1.0 + + return step_reward + + def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List: + result = [] + for ext_msg in ext_msg_array: + d = { + "role": ext_msg.role, + "content": ext_msg.content_for_future, + } + if ext_msg.tool_calls: + d.update({"tool_calls": ext_msg.tool_calls}) + if ext_msg.tool_call_id: + d.update({"tool_call_id": ext_msg.tool_call_id}) + result.append(d) + return result + + def group_tokenize_single_group(self, timeline): + sample_arr = [] + ext_steps = timeline + tracker_tokenized = self.tokenize_steps(ext_steps=ext_steps, index=0, total_steps=1) + sample = Sample( + tracker_tokenized=tracker_tokenized, + messages=self.to_role_content(ext_steps), + config=self.config, + task_batch_index=self.task_batch_index, + task_tag=self.task_tag, + task_id=self.task_id, + ) + sample.truncate_output_ids() + sample_arr += [sample] + return sample_arr + + def group_tokenize_multi_group(self): + sample_arr = [] + max_num_group = self.config.ajet.rollout.multi_turn.max_sample_per_task + for index, ext_steps in enumerate(self.saved_timelines): + tracker_tokenized = self.tokenize_steps( + ext_steps=ext_steps, + index=index, + total_steps=len(self.saved_timelines), + ) + sample = Sample( + tracker_tokenized=tracker_tokenized, + messages=self.to_role_content(ext_steps), + config=self.config, + task_batch_index=self.task_batch_index, + task_tag=self.task_tag, + task_id=self.task_id, + ) + sample_arr += [sample] + + if len(sample_arr) > max_num_group: + logger.warning(f"Warning: allow {max_num_group} groups, but got {len(sample_arr)} groups") + import random + + sample_arr = random.sample(sample_arr, max_num_group) # preserve max_num_group groups + + return sample_arr + + def tokenize_steps(self, ext_steps: List[ExtendedMessage], index: int, total_steps: int) -> dict: + """ + Create an Experience object from the current conversation context. + + Returns: + Experience: An object containing processed conversation data for model training + + Note: + - Removes the last user message as it's not required in casual model training + - Processes input IDs, attention masks, and loss masks + - Separates prompt and response components + - Handles position IDs and reward scores + - Truncates output IDs as needed + """ + from verl.utils.model import compute_position_id_with_mask + + ext_steps = self.remove_last_non_llm_msg(ext_steps) + + # check reward structure + self.reward_structure: Reward # type: ignore + assert self.reward_structure.step_reward_arr is not None, "must call `process_reward` before tokenize_steps" + assert len(self.reward_structure.step_reward_arr) == total_steps + + # mapping + input_ids = [] + input_logprobs = [] + attention_mask = [] + loss_mask = [] + split_prompt_reponse_index = -1 + split_point_message_left_index = -1 + input_ids_len = [] + + # cat all messages + for i, ext_msg in enumerate(ext_steps): + # find split index, this have to be done before input_ids += ext_msg.token_arr + if (split_prompt_reponse_index == -1) and (ext_msg.need_training): + split_prompt_reponse_index = len(input_ids) + split_point_message_left_index = i - 1 + assert split_point_message_left_index >= 0, "There should be at least one message before the first training message" + assert split_prompt_reponse_index == input_ids_len[split_point_message_left_index] + assert ext_msg.author == "llm", "The first message after initialization should be from LLM, not from env or user" + + # cat all tokens + input_ids += ext_msg.token_arr + if len(ext_msg.token_logprob_arr) == 0: + input_logprobs += [ext_msg.invalid_log_prob_value] * len(ext_msg.token_arr) + else: + input_logprobs += ext_msg.token_logprob_arr + input_ids_len += [len(input_ids)] + attention_mask += [1] * len(ext_msg.token_arr) + loss_mask += ext_msg.get_loss_mask(blackout_token_combo=self.blackout_token_combo) + + # if [prompt_token | response_token] is splited at a place where loss_mask == 0, + # move the split index forward + MAX_FORWARD_STEPS = 100 + for i in range(MAX_FORWARD_STEPS): + if loss_mask[split_prompt_reponse_index] == 0: + split_prompt_reponse_index += 1 + else: + break + + # no matter what, the split index should not exceed max prompt length + # make sure that the prompt length does not exceed `config.ajet.data.max_prompt_length` + if split_prompt_reponse_index > self.config.ajet.data.max_prompt_length: + split_prompt_reponse_index = self.config.ajet.data.max_prompt_length + + # check + assert len(ext_steps) == len(input_ids_len), "length of ext_steps and input_ids_len should be equal" + assert split_prompt_reponse_index != -1, "split_prompt_reponse_index should not be -1, at least one message should be in the context" + position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist() + + # sperate prompt and response + prompt_ids = input_ids[:split_prompt_reponse_index] + prompt_attention_mask = attention_mask[:split_prompt_reponse_index] + prompt_position_ids = position_ids[:split_prompt_reponse_index] + prompt_loss_mask = loss_mask[:split_prompt_reponse_index] + prompt_logprobs = input_logprobs[:split_prompt_reponse_index] + + response_ids = input_ids[split_prompt_reponse_index:] + response_attention_mask = attention_mask[split_prompt_reponse_index:] + response_position_ids = position_ids[split_prompt_reponse_index:] + response_loss_mask = loss_mask[split_prompt_reponse_index:] + response_logprobs = input_logprobs[split_prompt_reponse_index:] + + tracker_tokenized = {} + tracker_tokenized["input_ids"] = input_ids + tracker_tokenized["prompt_ids"] = prompt_ids + tracker_tokenized["response_ids"] = response_ids + tracker_tokenized["attention_mask"] = attention_mask + tracker_tokenized["logprobs"] = input_logprobs + tracker_tokenized["prompt_attention_mask"] = prompt_attention_mask + tracker_tokenized["response_attention_mask"] = response_attention_mask + tracker_tokenized["loss_mask"] = loss_mask + tracker_tokenized["prompt_loss_mask"] = prompt_loss_mask + tracker_tokenized["response_loss_mask"] = response_loss_mask + tracker_tokenized["position_ids"] = position_ids + tracker_tokenized["prompt_position_ids"] = prompt_position_ids + tracker_tokenized["response_position_ids"] = response_position_ids + tracker_tokenized["response_logprobs"] = response_logprobs + tracker_tokenized["prompt_logprobs"] = prompt_logprobs + + # distribute reward + tracker_tokenized["step_reward"] = self.reward_structure.step_reward_arr[index] + try: + tracker_tokenized["reference_advantage"] = self.reward_structure.step_advantage[index] + except Exception: + tracker_tokenized["reference_advantage"] = 0 + + return tracker_tokenized + + @staticmethod + def compute_reference_advantage(tracker_array: List): + import numpy as np + + task2tracker = defaultdict(list) + for tracker in tracker_array: + task2tracker[tracker.task_id] += [tracker] + + # compute group normalized step_advantage (just for logging purpose) + for task_id, tracker_list in task2tracker.items(): + tracker_reward = [] + + # compute in-group mean and std + for tracker in tracker_list: + tracker_reward += [np.mean(tracker.reward_structure.step_reward_arr)] + + if len(tracker_reward) == 1: + reward_mean = 0.0 + reward_std = 1.0 + else: + reward_mean = float(np.mean(tracker_reward)) + reward_std = float(np.std(tracker_reward, ddof=1)) + if reward_std < 0.01: + reward_std = 0.01 + + # compute advantage + for tracker in tracker_list: + tracker.reward_structure.step_advantage = [] + for i in range(len(tracker.reward_structure.step_reward_arr)): + tracker.reward_structure.step_advantage += [(tracker.reward_structure.step_reward_arr[i] - reward_mean) / (reward_std + 1e-6)] + + # compute simple advantage (uneven rollout sample count) (just for logging purpose) + for task_id, tracker_list in task2tracker.items(): + tracker_reward = [] + for tracker in tracker_list: + tracker_reward.extend(tracker.reward_structure.step_reward_arr) + if len(tracker_reward) == 1: + reward_mean = 0.0 + reward_std = 1.0 + else: + reward_mean = float(np.mean(tracker_reward)) + reward_std = float(np.std(tracker_reward, ddof=1)) + for tracker in tracker_list: + tracker.reward_structure.step_advantage_simple = [] + for i in range(len(tracker.reward_structure.step_reward_arr)): + tracker.reward_structure.step_advantage_simple += [(tracker.reward_structure.step_reward_arr[i] - reward_mean) / (reward_std + 1e-6)] + return + + def get_generation_prompt_token(self): + dummy_msg = [{"role": "assistant", "content": "dummy text"}] + self.generation_prompt_token, _ = self.get_inc( + ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=dummy_msg, + tools=[], + add_generation_prompt=False, + tokenize=False, + ), + ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=dummy_msg, + tools=[], + add_generation_prompt=True, + tokenize=False, + ), + ) + self.generation_prompt = self.tokenizer.decode(self.generation_prompt_token) + return self.generation_prompt_token + + def generate_log(self, task_id=None, global_step: str | int = "NA"): + """ + Generate log for the context tracker. + """ + raise NotImplementedError diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py new file mode 100644 index 00000000..ee9138b9 --- /dev/null +++ b/ajet/context_tracker/multiagent_tracking.py @@ -0,0 +1,531 @@ +# flake8: noqa: F541, F841 +import copy +import json +from dataclasses import dataclass, field +from typing import List, Tuple + +from beast_logger import NestedJsonItem, SeqItem, print_dict, print_nested +from loguru import logger +from transformers.tokenization_utils import PreTrainedTokenizer + +from ajet.context_tracker.timeline_merging.timeline_merging import merge_tracker_timelines, is_timeline_mergeable +from ajet.context_tracker.basic_tracker import ( + BaseContextTracker, + ExtendedMessage, +) +from ajet.schema.extended_msg import INVALID_LOG_PROB_VALUE +from ajet.schema.trajectory import Reward +from ajet.utils.color_hsl import adjust_color_hsl +from ajet.utils.compute_madness import compute_string_madness +from ajet.utils.tokenizer import ajet_apply_chat_template + + +@dataclass +class TimelineMergingPolicyConfig: + timeline_compare_level: str = "text" + ignore_tools: bool = True + + +@dataclass +class ContextTrackerConfig: + timeline_merging_policy: TimelineMergingPolicyConfig = field(default_factory=TimelineMergingPolicyConfig) + fix_retokenization_drift: bool = True + detect_timeline_snap: bool = False + + +class MultiAgentContextTracker(BaseContextTracker): + """ + Context tracker is responsible to monitor and process LLM IO. + Each context tracker is responsible for ONE episode run only. + """ + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + config, + should_interrupt_fn, + generated_token_callback_fn, + **kwargs, + ): + super().__init__(config, tokenizer, **kwargs) + self.tokenizer = tokenizer + self.should_interrupt_fn = should_interrupt_fn + self.generated_token_callback_fn = generated_token_callback_fn + self.context_overflow = False + self.output_kwargs = {} + self.input_kwargs = {} + self.timeline_cache = {} + + def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool = False): + if disable_toolcalls: + tools = [] + else: + if tools is not None: + # rerank tool parameters to improve compatibility + for i in range(len(tools)): + tools[i]["function"]["parameters"] = tools[i]["function"].pop("parameters") + return tools + + def extract_text_content_from_content_dict(self, msg): + # msg = { + # "role": "assistant", + # "content": [ + # { + # "type": "text", + # "text": "some text" + # }, + # ], + # } + + str_content = "" + for item in msg["content"]: + # item = { + # "type": "text", + # "text": "some text" + # }, + + assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" + + if "text" not in item: + logger.warning(f"Non-text content in message content detected: {item}. Ignoring.") + should_skip_message = True + return str_content, should_skip_message + + if isinstance(item["text"], str): + str_content += str(item["text"]) + else: + str_content = "" + + should_skip_message = False + return str_content, should_skip_message + + def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_toolcalls: bool = False) -> List[ExtendedMessage]: + """Spawn a timeline from messages. + + Args: + messages: List of message dictionaries + tools: List of tool dictionaries + disable_toolcalls: Whether to disable tool calls + + Returns: + List of ExtendedMessage objects representing the timeline + """ + timeline = [] + + consider_roles = ["user", "assistant", "system", "tool"] + if disable_toolcalls: + consider_roles.remove("tool") + + for i, msg in enumerate(messages): + if (disable_toolcalls) and (not isinstance(msg["content"], str)): + continue + + if msg["role"] not in consider_roles: + continue + + if not isinstance(msg["content"], str): + author = "env" + should_skip_message = False + + # fix msg content + if msg["content"] is None: + msg["content"] = "" + + elif isinstance(msg["content"], list): + msg["content"], should_skip_message = self.extract_text_content_from_content_dict(msg) + + else: + raise ValueError(f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}") + + if should_skip_message: + continue + + if not isinstance(msg["content"], str): + msg["content"] = str(msg["content"]) # TODO: better handling mm data + + if msg["role"] == "system": + author = "initialization" + + if msg["role"] == "tool": + author = "env" + else: + author = "env" + + timeline += [ + ExtendedMessage( + author=author, + role=msg["role"], + content=msg["content"], + tokenizer=self.tokenizer, + tools=tools, + tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []), + tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""), + token_generator="auto", + name=(msg["name"] if "name" in msg else ""), + first_message=(i == 0), + ) + ] + + return timeline + + def step_prepare(self, messages: List[dict], tools: List = [], timeline_uuid: str = ""): + disable_toolcalls = self.config.ajet.rollout.force_disable_toolcalls + tools = self.preprocess_tools_field(tools, disable_toolcalls=disable_toolcalls) + timeline = self.step_spawn_timeline(messages, tools, disable_toolcalls) + + # check token overflow + converted_message = self.to_role_content(timeline) + timeline = ExtendedMessage.check_and_merge_chained_tool_response(timeline, self.tokenizer) + context_safe, token_overflow, info = self.check_context_token_num_safe(converted_message, tools) + custom_sampling_params = {} + if not context_safe: + self.context_overflow = True + + self.timeline_cache[timeline_uuid] = timeline + return context_safe, token_overflow, info, converted_message, custom_sampling_params, tools + + def step_track( + self, + llm_output, + context_safe, + converted_message: List[dict], + tools: List = [], + timeline_uuid: str = "", + ): + assert timeline_uuid in self.timeline_cache, "Timeline UUID not found in cache. Please ensure `step_prepare` is called before `step_track`." + timeline = self.timeline_cache.get(timeline_uuid, []) + if not self.already_mad_flag: + if ( + compute_string_madness( + completion=llm_output["content"], + checklist=self.config.ajet.rollout.compute_madness_checklist, + ) + < 0.0 + ): + self.already_mad_flag = True + + tool_calls = self.detect_tool_call_madness(llm_output) + + llm_ext_msg = ExtendedMessage( + author="llm", + role="assistant", + content=llm_output["content"], + token_generator="manual", + tool_calls=tool_calls, + tokenizer=self.tokenizer, + ) + input_msg_ref = copy.deepcopy(converted_message) + ( + precise_manual_token, + token_logprob_arr, + loss_mask, + lack_normal_eos, + ) = self.get_token_inc_from_llm_response(input_msg_ref, llm_output, tools=tools) + llm_ext_msg.token_arr = precise_manual_token + llm_ext_msg.token_logprob_arr = token_logprob_arr + llm_ext_msg.lack_normal_eos = lack_normal_eos + llm_ext_msg.manual_loss_mask_override = loss_mask + + assert len(precise_manual_token) <= self.config.ajet.rollout.max_response_length_in_one_turn, f"Generated token length {len(precise_manual_token)} exceeds max_response_length_in_one_turn {self.config.ajet.rollout.max_response_length_in_one_turn}" + + # run generated token callback, usually to monitor token output rate ( e.g. 164 tokens/sec ) + self.generated_token_callback_fn(llm_ext_msg.token_arr) + + # take snapshot of current timeline + if context_safe: + if "prompt_text" in llm_output and "prompt_token_ids" in llm_output: + # currently we make this patch to better compat with Trinity training backend + # fix Retokenization Drift + timeline = self.patch_prompt_tokens( + prompt_text=llm_output["prompt_text"], + prompt_token_ids=llm_output["prompt_token_ids"], + previous_ext_context=timeline, + ) + + self.save_llm_interaction_timeline(tools, llm_ext_msg, timeline) + return None + + def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): + """Save the LLM interaction timeline by adding the LLM response to `self.saved_timelines`""" + timeline += [llm_ext_msg] + _, length = self.get_context_token_num_and_safety(timeline, tools) + if length > self.config.ajet.rollout.max_model_len: + raise RuntimeError(f"Unexpected token overflow after adding LLM response. Full context length {length}, generated token length {len(llm_ext_msg.token_arr)}") + + assert timeline[0].first_message, "First message should be marked as first_message" + + # assert all other message is not first_message + for i in range(1, len(timeline)): + assert not timeline[i].first_message + + # save to self.saved_timelines + self.saved_timelines += [copy.deepcopy(timeline)] + + # DEBUG = True # warn when merge fails + timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy + if self.config.ajet.context_tracker.detect_timeline_snap and len(self.saved_timelines) >= 2 and (not is_timeline_mergeable(self.saved_timelines[-1], self.saved_timelines[-2], timeline_merging_policy)): + logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n") + return + + def detect_tool_call_madness(self, llm_output): + """Detect whether the tool call format from LLM output is correct or not.""" + log_tool = self.config.ajet.context_tracker.log_tool_format_check + detailed_log = self.config.ajet.context_tracker.log_tool_format_error_detail + + err_type = "" + if llm_output.get("tool_calls", []): + # llm_output["tool_calls"] is not None, and is not [] + tool_calls = llm_output["tool_calls"] + if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist: + copy_tool_calls = copy.deepcopy(tool_calls) + wrong_toolcall = False + for i in range(len(copy_tool_calls)): + if ("function" in copy_tool_calls[i]) and ("arguments" in copy_tool_calls[i]["function"]): + try: + expect_dict = json.loads(copy_tool_calls[i]["function"]["arguments"]) + if not isinstance(expect_dict, dict): + wrong_toolcall = True + err_type = "cannot parse arguments" + except Exception: + wrong_toolcall = True + err_type = "arguments not json" + else: + wrong_toolcall = True + err_type = "no function or no arguments" + if wrong_toolcall: + if detailed_log: + logger.bind(exception=True).warning(f"Detected wrong toolcall format from LLM output: \n---*({err_type})*---\n{llm_output['tool_calls']}\n---*-*---\n") + if log_tool: + logger.bind(exception=True).warning(f"Detected wrong toolcall format from LLM content") + self.already_mad_flag = True + else: + if log_tool: + logger.success("Toolcall format check passed.") + + elif "" in llm_output["content"]: + if detailed_log: + logger.bind(exception=True).warning(f"Detected wrong toolcall format from LLM content: \n---*-*---\n{llm_output['content']}\n---*-*---\n") + if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist: + if log_tool: + logger.bind(exception=True).warning(f"Detected wrong toolcall format from LLM content") + self.already_mad_flag = True + tool_calls = [] + else: + tool_calls = [] + return tool_calls + + def patch_prompt_tokens( + self, + prompt_text: str, + prompt_token_ids: List[int], + previous_ext_context: List[ExtendedMessage], + ) -> List[ExtendedMessage]: + # remove tailing + if prompt_text.endswith(self.generation_prompt): + prompt_text = prompt_text[: -len(self.generation_prompt)] + # prompt_token_ids = prompt_token_ids[: -len(self.generation_prompt_token)] + + # split prompt token ids into message level + split_prompt_token_ids = [] + tmp = [] + for i in range(len(prompt_token_ids)): + if prompt_token_ids[i] != self._im_start_token_id: + tmp += [prompt_token_ids[i]] + else: + if len(tmp) > 0: + split_prompt_token_ids += [tmp] + tmp = [prompt_token_ids[i]] + if len(tmp) > 0: + split_prompt_token_ids += [tmp] + + # split prompt text into message level + prompt_text_split = prompt_text.split("<|im_start|>") + assert prompt_text_split[0] == "", "Prompt text should start with <|im_start|>" + prompt_text_split = prompt_text_split[1:] # remove the first empty string + for i in range(len(prompt_text_split)): + prompt_text_split[i] = "<|im_start|>" + prompt_text_split[i] + + current_prompt_text = [] + for j in range(len(previous_ext_context)): + current_prompt_text += [self.tokenizer.decode(previous_ext_context[j].token_arr)] + + if len(previous_ext_context) != len(prompt_text_split): + logger.bind(exception=True).error(f"Length mismatch when patching prompt tokens. Previous ext context length: {len(previous_ext_context)}, prompt text split length: {len(prompt_text_split)}. Replacing all tokens.") + + # try to recover tokens + if self.config.ajet.context_tracker.fix_retokenization_drift: + self.ensure_retokenization_perfect_match(previous_ext_context, split_prompt_token_ids, prompt_text_split, current_prompt_text) + + # remove extra messages + if len(previous_ext_context) != len(prompt_text_split): + previous_ext_context = previous_ext_context[: len(prompt_text_split)] + + return previous_ext_context + + def ensure_retokenization_perfect_match(self, previous_ext_context, split_prompt_token_ids, prompt_text_split, current_prompt_text): + for j in range(len(previous_ext_context)): + if prompt_text_split[j] != current_prompt_text[j]: + # if prompt text mismatch, we can replace the tokens + print_dict( + { + "expected_prompt_text": prompt_text_split[j], + "current_prompt_text": current_prompt_text[j], + }, + mod="exception", + header="Prompt text mismatch, Please report a github issue", + ) + previous_ext_context[j].token_arr = self.tokenizer(prompt_text_split[j], return_tensors="pt", padding=False) + else: + # if prompt text match + # we further check whether all token ids matches + vllm_token_array = split_prompt_token_ids[j] + tracker_token_array = previous_ext_context[j].token_arr + if vllm_token_array == tracker_token_array: + # good, everything is perfect + continue + else: + # otherwise, we throw a warning (do not worry, this causes almost no influence in the training) + print_dict( + { + "expected_token_ids": split_prompt_token_ids[j], + "current_token_ids": previous_ext_context[j].token_arr, + }, + mod="exception", + header="Prompt token ids mismatch, Please report a github issue", + ) + + def process_reward(self, reward_structure: Reward): + self.reward_structure = reward_structure + # TODO: support multi-step reward + # in current implementation, all reward in all step equals + # we'll implement fine-grained step reward in future versions + self.reward_structure.step_reward_arr = [ + self.compute_step_level_reward( + index=i, + total_steps=len(self.saved_timelines), + ) + for i in range(len(self.saved_timelines)) + ] + + def generate_log(self, task_id=None, global_step="NA"): + task_id = self.task_id + nested_items_print_buffer = {} + step_reward = 0.0 + + for index, ext_steps in enumerate(self.saved_timelines): + tracker_tokenized = self.tokenize_steps( + ext_steps=ext_steps, + index=index, + total_steps=len(self.saved_timelines), + ) + text_arr = [self.tokenizer.decode(t) for t in tracker_tokenized["input_ids"]] + input_id_arr = [str(t) for t in tracker_tokenized["input_ids"]] + # loss_mask_color_arr = ["#09ABCF" if mask==1 else "#D98510" for mask in tracker_tokenized["loss_mask"]] + logprobs = [INVALID_LOG_PROB_VALUE] * len(tracker_tokenized["prompt_ids"]) + tracker_tokenized["response_logprobs"] + # Create adjusted color array + loss_mask_color_abl_arr = [(adjust_color_hsl("#09ABCF", logprob) if mask == 1 else adjust_color_hsl("#D98510", logprob)) for mask, logprob in zip(tracker_tokenized["loss_mask"], logprobs)] + logprob_text_arr = [(f"{logprob:.4f}" if logprob != INVALID_LOG_PROB_VALUE else "N/A") for logprob in logprobs] + + buffer = { + "text_arr": text_arr, + "logprob_arr": logprob_text_arr, + "input_id_arr": input_id_arr, + "loss_mask_color_arr": loss_mask_color_abl_arr, + } + raw_reward = self.reward_structure.raw_reward + step_reward: float = self.reward_structure.step_reward_arr[index] + try: + step_advantage = self.reward_structure.step_advantage[index] + step_advantage_simple = self.reward_structure.step_advantage_simple[index] + except Exception: + step_advantage = 0.0 + step_advantage_simple = 0.0 + task_outcome = str(self.reward_structure.success_rate) + selectors = [task_id, task_outcome, str(index)] + len_prompt_ids = len(tracker_tokenized["prompt_ids"]) + len_response_ids = len(tracker_tokenized["response_ids"]) + len_input_ids = len(tracker_tokenized["input_ids"]) + assert len_prompt_ids + len_response_ids == len_input_ids, "len_prompt_ids + len_response_ids should equal to len_input_ids" + nested_items_print_buffer[".".join(selectors)] = NestedJsonItem( + item_id="item", # type: ignore + outcome=task_outcome, # type: ignore + len_prompt_ids=len_prompt_ids, # type: ignore + len_response_ids=len_response_ids, # type: ignore + len_input_ids=len_input_ids, # type: ignore + raw_reward=f"{float(raw_reward):.3f}", # type: ignore + step_reward=f"{float(step_reward):.3f}", # type: ignore + step_advantage=f"{float(step_advantage):.3f}", # type: ignore + step_advantage_simple=f"{float(step_advantage_simple):.3f}", # type: ignore + content=SeqItem( + text=buffer["text_arr"], # text content + title=buffer["logprob_arr"], # mouse hover text + count=buffer["input_id_arr"], # highlight text # type: ignore + color=buffer["loss_mask_color_arr"], # color + ), + ) + + print_nested( + nested_items_print_buffer, + main_content="This is the main content of the nested JSON", + header=f"[{global_step}] Task {task_id} (Reward {float(step_reward):.3f})", # type: ignore + mod="rollout", + narrow=False, + attach="copy this", # type: ignore + ) + + def group_merge(self) -> List[List[ExtendedMessage]]: + timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy + self.saved_timelines = merge_tracker_timelines(self.saved_timelines, timeline_merging_policy) + return self.saved_timelines + + def group_tokenize(self): + return self.group_tokenize_multi_group() + + def get_context_token_num_and_safety(self, ext_messages: List[ExtendedMessage], tools: List = []) -> Tuple[bool, int]: # type: ignore + dict_messages = self.to_role_content(ext_messages) + prompt_text = ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=dict_messages, + tools=tools, + add_generation_prompt=True, + tokenize=False, + ) + length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore + max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn + max_model_len: int = self.config.ajet.rollout.max_model_len + max_seq_length: int = max_model_len - max_response_length + + if length < max_seq_length: + return True, length + else: + return False, length + + def check_context_token_num_safe(self, messages: List, tools: List = []) -> Tuple[bool, bool, str]: + prompt_text = ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=messages, + tools=tools, + add_generation_prompt=True, + tokenize=False, + ) + length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) # type: ignore + max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn + max_model_len: int = self.config.ajet.rollout.max_model_len + max_seq_length: int = max_model_len - max_response_length + if length < max_seq_length: + token_overflow = False + else: + token_overflow = True + if self.should_interrupt_fn(): + ret = (False, token_overflow, "externally_interrupted") + elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination: + ret = (False, token_overflow, "already_mad") + elif length < max_seq_length: + ret = ( + True, + token_overflow, + f"safe[{length} < {max_model_len} - {max_response_length}]", + ) + else: + ret = (False, token_overflow, "token_overflow") + return ret diff --git a/ajet/context_tracker/timeline_merging/timeline_merging.py b/ajet/context_tracker/timeline_merging/timeline_merging.py new file mode 100644 index 00000000..3f5ce8ad --- /dev/null +++ b/ajet/context_tracker/timeline_merging/timeline_merging.py @@ -0,0 +1,117 @@ +from typing import List + +from beast_logger import print_listofdict + +from ajet.context_tracker.basic_tracker import ExtendedMessage + + +def is_timeline_mergeable( + source_timeline: List[ExtendedMessage], + target_timeline: List[ExtendedMessage], + timeline_merging_policy, + debug=False, +) -> bool: + # timeline_compare_level = "text" # relaxed compare with text, more easier to match, at very little cost + # timeline_compare_level = "token" # strict compare with token, cause less aggressive merging + timeline_compare_level = timeline_merging_policy.get("timeline_compare_level", "text") + ignore_tools = timeline_merging_policy.get("ignore_tools", True) + + can_merge = False + if len(source_timeline) >= len(target_timeline): + all_msg_match = True + for i in range(len(target_timeline)): + if timeline_compare_level == "text": + same = source_timeline[i].content_for_future == target_timeline[i].content_for_future + elif timeline_compare_level == "token": + same = source_timeline[i].token_arr == target_timeline[i].token_arr + else: + raise NotImplementedError + + if not same: + all_msg_match = False + break + + # compare whether avail tool list is identical when (not ignore_tools) + if (all_msg_match) and (not ignore_tools): + source_0 = source_timeline[0] + target_0 = target_timeline[0] + if source_0.tools != target_0.tools: + all_msg_match = False + + if all_msg_match: + can_merge = True + + # # developer only: code below is only for debugging (print a nice comparison table) + # if debug: + # debug_listofdict = [] + # if len(source_timeline) >= len(target_timeline): + # all_msg_match = False + # for i in range(len(target_timeline)): + # d = {} + # d["source"] = source_timeline[i].content_for_future + # d["target"] = target_timeline[i].content_for_future + # if timeline_compare_level == "text": + # same = ( + # source_timeline[i].content_for_future + # == target_timeline[i].content_for_future + # ) + # elif timeline_compare_level == "token": + # same = source_timeline[i].token_arr == target_timeline[i].token_arr + # else: + # raise NotImplementedError + # if not same: + # d["match"] = "NO" + # else: + # d["match"] = "YES" + # debug_listofdict.append(d) + # print_listofdict(debug_listofdict, header=f"is_timeline_mergeable debug: {can_merge}") + + return can_merge + + +def toggle_author_and_mask( + source_timeline: List[ExtendedMessage], + target_timeline: List[ExtendedMessage], +) -> List[ExtendedMessage]: + # if any message in `target_timeline` is author == 'llm', + # but same-index message in `source_timeline` is author != 'llm' + # change source_timeline's message author to 'llm' + # overwrite `token_arr` and `token_logprob_arr` accordingly + for i in range(len(target_timeline)): + if target_timeline[i].author == "llm" and source_timeline[i].author != "llm": + source_timeline[i].author = target_timeline[i].author + source_timeline[i].token_arr = target_timeline[i].token_arr + source_timeline[i].token_logprob_arr = target_timeline[i].token_logprob_arr + assert source_timeline[i].need_training + return source_timeline + + +def merge_tracker_timelines(timelines: List[List[ExtendedMessage]], timeline_merging_policy) -> List[List[ExtendedMessage]]: + """Merge multiple timelines by absorbing those that can be merged. + > Input: a list of timelines. (a timeline means: List[ExtendedMessage]) + > Output: a shorter list of timelines after merging + """ + absorbed_step_indices = [] + reversed_timelines = list(reversed(timelines)) + for i in range(len(reversed_timelines)): + if i in absorbed_step_indices: + continue + # check whether [i, i+1, ..., ..., len(reversed_timelines)-1] can be merged + for j in range(i + 1, len(reversed_timelines)): + if j in absorbed_step_indices: + continue + source_timeline = reversed_timelines[i] + target_timeline = reversed_timelines[j] + if is_timeline_mergeable(source_timeline, target_timeline, timeline_merging_policy): + source_timeline = toggle_author_and_mask(source_timeline, target_timeline) + reversed_timelines[i] = source_timeline + absorbed_step_indices += [j] + + # reverse back and exclude absorbed steps + reversed_timelines_clean = [] + for i in range(len(reversed_timelines)): + if i not in absorbed_step_indices: + reversed_timelines_clean.append(reversed_timelines[i]) + timelines = list(reversed(reversed_timelines_clean)) + + return timelines diff --git a/astune/env_service_client/__init__.py b/ajet/copilot/__init__.py similarity index 100% rename from astune/env_service_client/__init__.py rename to ajet/copilot/__init__.py diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py new file mode 100644 index 00000000..37ba2c52 --- /dev/null +++ b/ajet/copilot/job.py @@ -0,0 +1,168 @@ +"""Programmatic training entry point for AgentJet. + +This class mirrors the CLI launcher by materializing a YAML config and +spawning a subprocess to run the existing training pipeline. The goal is to +keep the public surface minimal while reusing the mature CLI code paths. +""" + +from __future__ import annotations + +import os +import tempfile +from datetime import datetime +from types import SimpleNamespace +from typing import Any, Callable, Union + +import ray +import yaml +from loguru import logger + +from ajet.launcher import ( + check_avail_gpu, + get_backbone_target, + setup_environment_vars, +) +from ajet.default_config.ajet_default import Config +from ajet.utils.config_utils import ( + expand_ajet_hierarchical_config, + prepare_experiment_config, + read_ajet_hierarchical_config, +) +from ajet.utils.dynamic_import import cls_to_path +from ajet.utils.launch_utils import execute_training_process + + +class AgentJetJob: + """Lightweight builder that launches AgentJet training as a subprocess.""" + + def __init__( + self, + backbone: str = "trinity", + model: str = "Qwen/Qwen2___5-7B-Instruct", + n_gpu: int = 8, + algorithm: str = "grpo", + n_gpu_for_infer: int | None = None, # only for trinity backbone + *kwargs, + ) -> None: + self.backbone = backbone + self.config_as_dict: dict = self.build_job_from_yaml(None) + self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) + + self.config.ajet.backbone = backbone + self.config.ajet.model.path = model + self.config.ajet.trainer_common.n_gpus_per_node = n_gpu + self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm + if n_gpu_for_infer is None and backbone == "trinity": + raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") + if n_gpu_for_infer is not None and backbone == "verl": + raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.") + else: + assert isinstance(n_gpu_for_infer, int) + assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." + self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer + self.config.ajet.rollout.tensor_model_parallel_size = 1 + + def build_job_from_yaml(self, yaml_path: str | None) -> dict: + self.exp_name = datetime.now().strftime("ajet_job_%Y%m%d_%H%M%S") + self.exp_dir_final = "saved_experiments" + self.config_as_dict = read_ajet_hierarchical_config( + yaml_path, + exp_name=self.exp_name, + backbone=self.backbone, + write_to=None, + exp_dir=self.exp_dir_final, + ) + self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None) + logger.info(f"Built AgentJet job config: {yaml_path}") + return self.config_as_dict + + def dump_job_as_yaml(self, yaml_path: str) -> str: + if os.path.dirname(yaml_path): + os.makedirs(os.path.dirname(yaml_path), exist_ok=True) + with open(yaml_path, "w", encoding="utf-8") as f: + yaml.safe_dump(self.config.to_dict(), f, sort_keys=False) + logger.info(f"Saved training config to {yaml_path}") + return yaml_path + + def set_workflow(self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False) -> "AgentJetJob": + self.config.ajet.rollout.user_workflow = cls_to_path(workflow) + # TODO: validate workflow outputs contain reward + # ensure_reward_in_workflow + return self + + def set_data( + self, + type: str, + dataset_path: str, + training_split: str = "train", + validation_split: str = "test", + ) -> "AgentJetJob": + """Configure the task reader. Defaults to HuggingFace datasets.""" + + # available types: + # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + if type in {"hf", "huggingface", "huggingface_dat_repo"}: + self.config.ajet.task_reader.type = "huggingface_dat_repo" + self.config.ajet.task_reader.huggingface_dat_repo.dataset_path = dataset_path + self.config.ajet.task_reader.huggingface_dat_repo.training_split = training_split + self.config.ajet.task_reader.huggingface_dat_repo.validation_split = validation_split + elif type in {"random_dummy", "dummy"}: + self.config.ajet.task_reader.type = "random_dummy" + else: + raise NotImplementedError(f"Please edit yaml to directly set up task reader of type {type}.") + + return self + + def tune(self, *args, **kwargs) -> "AgentJetJob": + ast_cfg = self.config.ajet + if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow: + raise ValueError("Workflow must be set via set_workflow before tuning.") + if not ast_cfg.task_reader: + raise ValueError("Data source must be set via set_data before tuning.") + + backbone = self.config.ajet.backbone + exp_dir = self.config.ajet.experiment_dir + + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml: + yaml_path = temp_yaml.name + self.dump_job_as_yaml(yaml_path) + args = SimpleNamespace( + conf=yaml_path, + backbone=backbone, + exp_dir=exp_dir, + with_logview=False, + debug=False, + ) + + if args.backbone != "debug": + # Enforce GPU availability and free memory threshold before proceeding + check_avail_gpu(min_free_ratio=0.95) + + # finalize experiment config + main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(yaml_path, exp_dir, backbone) + + # setup environment variables for ray + env = setup_environment_vars(args, exp_config, main_yaml_fp) + + # start ray if not already started + if not ray.is_initialized(): + from ajet.utils.launch_utils import start_ray_service + + start_ray_service(args, env) + else: + raise RuntimeError("Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job.") + + # start training process + if args.conf and main_yaml_fp and exe_exp_base and exp_config: + execute_training_process( + args, + get_backbone_target(args.backbone), + main_yaml_fp, + exe_exp_base, + main_yaml_fp, + env, + exp_config, + ) + + return self diff --git a/astune/schema/__init__.py b/ajet/data_generator/__init__.py similarity index 100% rename from astune/schema/__init__.py rename to ajet/data_generator/__init__.py diff --git a/ajet/data_generator/base_data_generator.py b/ajet/data_generator/base_data_generator.py new file mode 100644 index 00000000..0ee47d8e --- /dev/null +++ b/ajet/data_generator/base_data_generator.py @@ -0,0 +1,77 @@ +from typing import Any, Dict, List, Optional, Union + +from ajet.schema.document import Document +from ajet.schema.task import Task +from ajet.task_rollout.dashscope_llm_bridge import create_external_llm_fn + + +class BaseDataGenerator: + def __init__(self, config): + """ + Initialize the TaskGeneratorBase class. + + Args: + config: Optional configuration object (LLM model, Maximum response length) + """ + self.config = config + self.sampling_params = self.config.data_generation.sampling_params or {} + self.llm_client = create_external_llm_fn( + alien_llm_model=self.config.data_generation.llm_model, + alien_llm_response_length=self.config.data_generation.llm_response_length, + ) + + def generate_task( + self, + source_task: Optional[Task] = None, + document: Optional[Document] = None, + extra_metadata: Optional[Dict[str, Any]] = None, + ) -> Union[Task, List[Task]]: + """ + Generate a new task. + + Args: + source_task: Source task for imitation (optional) + document: Knowledge source (optional) + extra_metadata: Additional metadata for the new task + + Returns: + Generated Task instance + """ + system_prompt = self._build_system_prompt(source_task, document) + user_prompt = self._build_user_prompt(source_task, document) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # Call the new LLM client + # Returns: {"role": "assistant", "content": "..."} + response = self.llm_client(messages=messages, sampling_params_override=self.sampling_params) + # Extract content from response + raw_response = response.get("content", "") + new_task = self._parse_llm_output_to_task(raw_response, source_task, document, extra_metadata) + return new_task + + def _build_system_prompt( + self, + source_task: Optional[Task], + document: Optional[Document], + ) -> str: + raise NotImplementedError + + def _build_user_prompt( + self, + source_task: Optional[Task], + document: Optional[Document], + ) -> str: + raise NotImplementedError + + def _parse_llm_output_to_task( + self, + raw_response: Any, + source_task: Optional[Task], + document: Optional[Document] = None, + extra_metadata: Optional[Dict[str, Any]] = None, + ) -> Union[Task, List[Task]]: + raise NotImplementedError diff --git a/ajet/data_generator/config.py b/ajet/data_generator/config.py new file mode 100644 index 00000000..64bbadc1 --- /dev/null +++ b/ajet/data_generator/config.py @@ -0,0 +1,93 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field, field_validator + + +class SamplingParamsConfig(BaseModel): + """Sampling parameters configuration""" + + temperature: float = Field(default=0.0, description="Sampling temperature, 0 means greedy decoding") + + +class DeduplicationFilterParamsConfig(BaseModel): + """Deduplication filter parameters configuration""" + + similarity_threshold: float = Field(default=0.8, description="Similarity threshold. Tasks above this value will be filtered.") + db_path: str = Field(default="./.similarity_db", description="Storage path for the similarity database") + model: str = Field(default="text-embedding-v4", description="Embedding model name") + api_key: Optional[str] = Field(default=None, description="API Key. If None, it is loaded from environment variables.") + base_url: str = Field( + default="https://dashscope.aliyuncs.com/compatible-mode/v1", + description="Embedding API base URL", + ) + + +class TrainingDatasetConfig(BaseModel): + """Training dataset configuration""" + + file_path: str = Field(default="", description="Path to the training data file") + + +class DeduplicationFilterConfig(BaseModel): + """Deduplication filter configuration""" + + enabled: bool = Field(default=True, description="Whether to enable the filter") + params: DeduplicationFilterParamsConfig = Field(default_factory=DeduplicationFilterParamsConfig) + + +class DocumentReaderConfig(BaseModel): + """Document reader configuration""" + + document_path: List[str] = Field(default=[], description="List of document file paths") + languages: List[str] = Field(default=["eng"], description="List of document languages") + chunk_size: int = Field(default=5120, description="Chunk size") + split_by: str = Field(default="sentence", description="Split method: sentence, paragraph, character") + cache_enabled: bool = Field(default=True, description="Whether to enable caching") + + @field_validator("split_by") + @classmethod + def validate_split_by(cls, v: str) -> str: + allowed = ["sentence", "paragraph", "character"] + if v not in allowed: + raise ValueError(f"split_by must be one of {allowed}, current value: {v}") + return v + + +class DatasetFileConfig(BaseModel): + """Dataset file configuration""" + + training: TrainingDatasetConfig = Field(default_factory=TrainingDatasetConfig) + + +class QueryReaderConfig(BaseModel): + """Query reader configuration""" + + type: str = Field(default="jsonl_dataset_file", description="Reader type") + jsonl_dataset_file: DatasetFileConfig = Field(default_factory=DatasetFileConfig) + + @field_validator("type") + @classmethod + def validate_type(cls, v: str) -> str: + allowed = ["jsonl_dataset_file", "env_service", "huggingface_dat_repo"] + if v not in allowed: + raise ValueError(f"type must be one of {allowed}, current value: {v}") + return v + + +class DataGenerationConfig(BaseModel): + """Data generation configuration""" + + document_reader: DocumentReaderConfig = Field(default_factory=DocumentReaderConfig) + query_reader: QueryReaderConfig = Field(default_factory=QueryReaderConfig) + task_num: int = Field(default=10, description="Number of tasks to generate") + llm_model: str = Field(default="qwen-long", description="LLM model name") + llm_response_length: int = Field(default=8192, description="LLM maximum response length") + num_workers: int = Field(default=32, description="Number of parallel worker threads") + sampling_params: SamplingParamsConfig = Field(default_factory=SamplingParamsConfig) + deduplication_filter: DeduplicationFilterConfig = Field(default_factory=DeduplicationFilterConfig) + + +class TaskReaderConfig(BaseModel): + """Task reader configuration""" + + data_generation: DataGenerationConfig = Field(default_factory=DataGenerationConfig) diff --git a/ajet/data_generator/knowledge_augmentation.py b/ajet/data_generator/knowledge_augmentation.py new file mode 100644 index 00000000..0663fdd9 --- /dev/null +++ b/ajet/data_generator/knowledge_augmentation.py @@ -0,0 +1,104 @@ +import json +import re +from typing import Any, Dict, List, Optional + +from ajet.data_generator.base_data_generator import BaseDataGenerator +from ajet.schema.document import Document +from ajet.schema.task import Task + + +class KnowledgeAugmentor(BaseDataGenerator): + """ + Knowledge Augmentation: + - Generate new tasks from Document + """ + + def _build_system_prompt( + self, + source_task: Optional[Task] = None, + document: Optional[Document] = None, + ) -> str: + # The template can be read from self.config, but here we hardcode an example for now + return "You are an Expert Question Generation Assistant.\n" "Your task is to read long, complex documents and generate a large set of high-quality, non-repetitive questions that thoroughly cover all aspects of the provided content.\n" "**Global Rules:**\n" "1. Coverage: Cover all sections, topics, major themes, nuanced details, facts, arguments, examples.\n" "2. Diversity: Include factual, conceptual, comparative, analytical, application, and critical thinking questions. Avoid overly trivial or repetitive questions.\n" "3. Quality: Questions must be clear, specific, unique, and relevant to the document. Avoid vague or generic questions.\n" "4. Depth: Include multi-step reasoning, chronological, cause-effect, data-driven, and abstract-contextual questions.\n" "5. Formatting: Output must be in a JSON list of dictionaries, each dictionary containing `query` and `related_doc` keys.\n" " - `query` = the generated question (one sentence, ending with a question mark)\n" " - `related_doc` = the exact excerpt or closely matching text from the document that related_docs or relates to the question\n" "6. Boundaries: The `related_doc` field must be taken directly from the provided document; do not fabricate or introduce information from outside sources.\n" "7. Few-shot: If given sample questions, match style and complexity but ensure diversity.\n" "8. Non-repetition: Ensure no two questions are duplicates or paraphrases of the same idea. If content overlaps, merge rather than replicate.\n" "Always strictly follow these rules in every output." + + def _build_user_prompt( + self, + source_task: Optional[Task] = None, + document: Optional[Document] = None, + ) -> str: + if document is None or not document.content: + raise ValueError("KnowledgeAugmentor requires a document for reference.") + + ref_doc = document.content + + user_part = [] + N = 10 # 10 is the hyperparameter we found that produces relatively stable outputs + user_part.append(f"Generate exactly {N} unique, high-quality questions from the following document according to the rules in the system prompt above.") + user_part.append("For each question, provide the corresponding reference excerpt from the document in the `related_doc` field.") + user_part.append("[DOCUMENT START]") + user_part.append(ref_doc) + user_part.append("[DOCUMENT END]") + user_part.append("Now generate queries that is suitable for the JSON format.") + user_part.append("Return your output strictly in JSON format as follows:") + user_part.append("[") + user_part.append(' {"query": "Question text here?", "related_doc": "Direct excerpt from the document here."},') + user_part.append(' {"query": "Question text here?", "related_doc": "Direct excerpt from the document here."},') + user_part.append("]") + return "\n".join(user_part) + + def _parse_llm_output_to_task( + self, + raw_response: Any, + source_task: Optional[Task] = None, + document: Optional[Document] = None, + extra_metadata: Optional[Dict[str, Any]] = None, + ) -> List[Task]: + """ + Parse LLM output and convert it to a Task object. + """ + # Handle different response formats from various LLM clients + if isinstance(raw_response, dict) and "content" in raw_response: + # Compatible with certain client return structures + response = raw_response["content"] + else: + response = str(raw_response) + + # Parse JSON from LLM response + try: + data = self._parse_json_response(response) + except Exception as e: + raise ValueError(f"Failed to parse LLM JSON output: {e}. Raw response: {response}") + + # data: List[Dict[str, str]] + all_generated_tasks = [] + for task in data: + # Extract the generated query from parsed JSON + new_query = task.get("query", "").strip() + if not new_query: + continue + related_doc = task.get("related_doc", "").strip() + # Construct metadata for the new task + new_metadata = {} + if extra_metadata: + new_metadata.update(extra_metadata) + # Store provenance information for traceability + if related_doc: + new_metadata["related_doc"] = related_doc + new_metadata["related_doc_source"] = document.doc_id + new_task = Task( + main_query=new_query, + init_messages=[], + task_id="", # Will be assigned by the system later + env_type=source_task.env_type if source_task else "no_env", + metadata=new_metadata, + ) + all_generated_tasks.append(new_task) + return all_generated_tasks + + def _parse_json_response(self, response: str) -> Optional[Dict[str, Any]]: + """ + Parse LLM response string into JSON. + """ + # Remove Markdown code block markers (```json and ```) if present + response = re.sub(r"^```json|```$", "", response, flags=re.MULTILINE).strip() + return json.loads(response) diff --git a/ajet/data_generator/task_augmentation.py b/ajet/data_generator/task_augmentation.py new file mode 100644 index 00000000..5a4a2de5 --- /dev/null +++ b/ajet/data_generator/task_augmentation.py @@ -0,0 +1,119 @@ +import re +from typing import Any, Dict, Optional, Union + +from ajet.data_generator.base_data_generator import BaseDataGenerator +from ajet.schema.document import Document +from ajet.schema.task import Task + + +class TaskAugmentor(BaseDataGenerator): + """ + Task Augmentation: + - Generate new queries based on reference Query (and optional Document) + """ + + def _build_system_prompt( + self, + source_task: Optional[Task], + document: Optional[Document] = None, + ) -> str: + """ + Build system prompt for task augmentation. + The prompt adapts based on whether a document is provided. + """ + base_prompt = "You are a professional expert in query generation.\n" "Your goal is to generate ONE new user query that:\n" "- Is semantically related to the reference query (similar topic/domain/intent),\n" "- Preserves the original query's style, language, task type, and approximate length,\n" "- Is natural, diverse, and fluent,\n" "- Is NOT a direct copy or minor edit of the original query.\n" + + # Conditional instructions based on document availability + document_instructions = "" + if document is not None and document.content: + document_instructions = "\n" "Document context is provided for reference:\n" "- Infer the document's overall topic or domain (do NOT assume the query is tied to a specific paragraph),\n" "- Ensure the new query is compatible with that overall topic/domain,\n" "- The new query should feel naturally related to the document's theme.\n" "\n" + + # Output format requirements to ensure structured response + output_requirements = "You MUST:\n" "- Avoid copying the original text verbatim,\n" "- Avoid minimal edits such as just changing a few words or reordering phrases,\n" "- Avoid adding explanations or commentary,\n" "- Output ONLY a valid JSON object with a single field 'query'.\n" "\n" "Example output format:\n" '{"query": ""}\n' + + return base_prompt + document_instructions + output_requirements + + def _build_user_prompt( + self, + source_task: Optional[Task], + document: Optional[Document] = None, + ) -> str: + """ + Build user prompt for task augmentation. + Handles both document-present and document-absent scenarios. + """ + if source_task is None or not source_task.main_query: + raise ValueError("TaskAugmentor requires a task for reference.") + + original_query = source_task.main_query + + # Build the reference part (query + optional document) + reference_info = "Reference information:\n" f"[Query]: {original_query}\n" + + # Add document content if provided + doc_part = "" + if document is not None and document.content: + # Only add document-related content if a document is actually provided + reference_info += "[Document]:\n" + doc_part = "Here is the reference document content:\n" f"{document.content}\n" "\n" "Use this document as background knowledge while generating a new query.\n" + + user_prompt = f"{reference_info}" f"{doc_part}" "\n" "Now, generate ONE new user query that is suitable for the same context.\n" "\n" "Important rules:\n" "- Do NOT directly copy or minimally edit the original query.\n" "- Do NOT output explanations, comments, or any extra text.\n" "- Output ONLY a JSON object with the following structure:\n" '{"query": ""}\n' + + return user_prompt + + def _parse_llm_output_to_task( + self, + raw_response: Any, + source_task: Optional[Task], + document: Optional[Document] = None, + extra_metadata: Optional[Dict[str, Any]] = None, + ) -> Task: + """ + Parse LLM output and convert it to a Task object. + """ + # Handle different response formats from various LLM clients + if isinstance(raw_response, dict) and "content" in raw_response: + # Compatible with certain client return structures + response = raw_response["content"] + else: + response = str(raw_response) + + # Parse JSON from LLM response + try: + data = self._parse_json_response(response) + except Exception as e: + raise ValueError(f"Failed to parse LLM JSON output: {e}. Raw response: {response}") + + # Extract the generated query from parsed JSON + new_query = data.get("query", "").strip() + if not new_query: + raise ValueError(f"No 'query' field found in LLM output JSON. Raw JSON: {data}") + + # Construct metadata for the new task + new_metadata = {} + if extra_metadata: + new_metadata.update(extra_metadata) + # Store provenance information for traceability + new_metadata["source_task_id"] = source_task.task_id if source_task else "" + new_metadata["aug_type"] = "task_augmentation" + if document: + new_metadata["source_doc_id"] = document.doc_id + + new_task = Task( + main_query=new_query, + init_messages=[], + task_id="", # Will be assigned by the system later + env_type=source_task.env_type if source_task else "no_env", + metadata=new_metadata, + ) + return new_task + + def _parse_json_response(self, response: str) -> Union[dict, list, str, float, int, bool, None]: + """ + Parse LLM response string into JSON. + """ + # Remove Markdown code block markers (```json and ```) if present + response = re.sub(r"^```json|```$", "", response, flags=re.MULTILINE).strip() + from agentscope._utils._common import _json_loads_with_repair + + return _json_loads_with_repair(response) diff --git a/astune/context_manager/cmt_memory.py b/ajet/default_config/README.md similarity index 100% rename from astune/context_manager/cmt_memory.py rename to ajet/default_config/README.md diff --git a/ajet/default_config/ajet_default.py b/ajet/default_config/ajet_default.py new file mode 100644 index 00000000..9f011335 --- /dev/null +++ b/ajet/default_config/ajet_default.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass, field +from typing import Any, Dict + + +@dataclass +class AjetAlgorithm: + adv_estimator: str = "grpo" + + +@dataclass +class AjetTrainerCommon: + n_gpus_per_node: int = 8 + algorithm: AjetAlgorithm = field(default_factory=AjetAlgorithm) + + +@dataclass +class AjetModel: + path: str = "/path/to/model/such/as/Qwen/Qwen2___5-14B-Instruct" + + +@dataclass +class AjetData: + max_prompt_length: int = 3000 + max_response_length: int = 15000 + train_batch_size: int = 32 + + +@dataclass +class AjetRollout: + user_workflow: str = "tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow" + n_vllm_engine: int = 1 + tensor_model_parallel_size: int = 1 + + +@dataclass +class HuggingfaceDatRepo: + dataset_path: str = "gsm8k" + training_split: str = "train" + validation_split: str = "validation" + + +@dataclass +class AjetTaskReader: + type: str = "huggingface_dat_repo" + huggingface_dat_repo: HuggingfaceDatRepo = field(default_factory=HuggingfaceDatRepo) + + +@dataclass +class AjetDefaultConfig: + project_name: str = "ajet_default_project" + experiment_name: str = "read_yaml_name" + experiment_dir: str = "auto" + backbone: str = "debug" + + model: AjetModel = field(default_factory=AjetModel) + data: AjetData = field(default_factory=AjetData) + rollout: AjetRollout = field(default_factory=AjetRollout) + trainer_common: AjetTrainerCommon = field(default_factory=AjetTrainerCommon) + task_reader: AjetTaskReader = field(default_factory=AjetTaskReader) + + +@dataclass +class Config: + ajet: AjetDefaultConfig = field(default_factory=AjetDefaultConfig) + + @staticmethod + def _to_dict(obj: Any) -> Any: + """Recursively convert dataclass objects to dictionaries.""" + result = {} + for key, value in obj.__dict__.items(): + if hasattr(value, "__dataclass_fields__"): + result[key] = Config._to_dict(value) + else: + result[key] = value + return result + + def to_dict(self) -> Dict[str, Any]: + """Convert the dataclass to a dictionary, including extra attributes.""" + return Config._to_dict(self) + + @staticmethod + def update_from_dict_recursive(config_as_dataclass, config_as_dict: Dict[str, Any]) -> "Config": + # read and assign + for key in config_as_dict.keys(): + target_value = config_as_dict[key] + if isinstance(target_value, dict): + if hasattr(config_as_dataclass, key): + if isinstance(getattr(config_as_dataclass, key), dict): + setattr(config_as_dataclass, key, target_value) + continue + else: + setattr( + config_as_dataclass, + key, + Config.update_from_dict_recursive(getattr(config_as_dataclass, key), target_value), + ) + else: + setattr(config_as_dataclass, key, target_value) + else: + setattr(config_as_dataclass, key, target_value) + return config_as_dataclass diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml new file mode 100644 index 00000000..fb4a6143 --- /dev/null +++ b/ajet/default_config/ajet_default.yaml @@ -0,0 +1,308 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: "ajet_default_project" + experiment_name: "read_yaml_name" + experiment_dir: "auto" # {exp-dir}/{experiment_name} + backbone: debug # `debug` or `trinity` or `verl` + + + model: + # which model should be trained + path: /path/to/model/such/as/Qwen/Qwen2___5-14B-Instruct + + data: + # max number of tokens for prompt + max_prompt_length: 3000 + # max number of tokens for response + max_response_length: 15000 + # how many tasks per training batch + train_batch_size: 32 + # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) + + + rollout: + + # the path to the workflow class + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + + # whether or not to disable all tool calls + force_disable_toolcalls: False + + # maximum number of parallel environments / simulate workers + max_env_worker: 64 + + # step reward gamma (experimental, do not change) + gamma: 1.0 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + # send signal to terminate context tracing when LLM is losing control + agent_madness_termination: True # terminate_after_gone_mad + # punish the LLM when it is detected as lost control + agent_madness_reward: -1.0 + + # max response length in one turn + max_response_length_in_one_turn: 4096 + + # max token length allowed for the model during rollout + max_model_len: 18000 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 30 + # limit the maximum steps for each task + max_steps: 30 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + # TP size for rollout engine + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 1 + + # how many sequences are allowed to be processed in parallel by each vllm engine + max_num_seqs: 10 + + # the usage of infer engine, options: (vllm, sglang) + name: vllm + + # how many times a task should be repeated + num_repeat: 4 + + # rollout kwargs + temperature: 0.9 + top_p: 1.0 + + # validation kwargs + val_kwargs: + # when doing validation, the sample setting when generating response + temperature: 0.0 + top_k: -1 + top_p: 1.0 + do_sample: False + num_repeat: 1 + + + task_reader: + type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + # when `type == jsonl_dataset_file` + jsonl_dataset_file: + training: + file_path: "/path/to/training/data.jsonl" + validation: + file_path: "/path/to/validation/data.jsonl" + + # when `type == env_service` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + # when `type == huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: "gsm8k" + training_split: "train" + validation_split: "validation" + + # when `type == data_generation` + data_generation: + document_reader: + document_path: + - 'dataset/document/your-document1.pdf' + - 'dataset/document/your-document2.pdf' + languages: + - eng + chunk_size: 5120 + split_by: "sentence" + cache_enabled: true + query_reader: + type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: 'dataset/jsonl/your-queries.jsonl' + task_num: 10 + llm_model: qwen-long + llm_response_length: 8192 + num_workers: 32 + sampling_params: + temperature: 0 + deduplication_filter: + enabled: true + params: + similarity_threshold: 0.8 + db_path: ./.similarity_db + model: text-embedding-v4 + api_key: null # load from the env + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + + # when `judge_type == customized_protocol` + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + # the helper LLM model used for LLM-AS-Judge + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 512 + + # when `judge_type == rubrics_auto_grader` + rubrics_auto_grader: + model_name: qwen-max + grader_mode: pointwise + language: en + query_specific_generate_number: 1 + enable_categorization: false + categories_number: 5 + grader_name: "auto_grader" + query_field: main_query + answer_field: final_answer + reference_field: answer + custom_evaluation_prompt: null # dict or PromptTemplate or None + input_data_type: jsonl_dataset_file # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` + jsonl_dataset_file: + training: + file_path: "tutorial/example_rm_auto_grader/rubrics_train.jsonl" + # Pointwise mode settings + min_score: 0 + max_score: 1 + + + + # context tracker protocol is valid ONLY when `use_agentscope_protocol=False` + context_tracker: + + # timeline merging policy used in Context Tracker + timeline_merging_policy: + + # compare_level = "text": relaxed compare with text, more easier to match, at very little cost + # compare_level = "token": strict compare with token, cause less aggressive merging + timeline_compare_level: "text" # options: "text", "token" + + # whether or not to ignore tool calls when comparing steps, default to `True` to make merging more aggressive + ignore_tools: True + + # Fix Retokenization Drift: inconsistencies between training and inference token array + # Related reading: https://github.com/vllm-project/vllm/pull/22587 (note that the implementation is very different) + fix_retokenization_drift: True + + # log tool format check results + log_tool_format_check: False + + # log tool format check results + log_tool_format_error_detail: False + + # detect at which point timeline stop growing linearly and cause a snap during a episode: this will cause additional computation. + detect_timeline_snap: False + + # deprecated + alien_llm_model: qwen3-235b-a22b-instruct-2507 + + # deprecated + alien_llm_response_length: 512 + + + # when backbone is `debug`, debug related configurations + debug: + + # max parallel runners in debug mode + debug_max_parallel: 4 + + # how many task to sample from training set + debug_first_n_tasks: 2 + + # what is the vllm engine port in the background + debug_vllm_port: 18000 + + # what is the seed of the vllm engine in the background + debug_vllm_seed: 12345 + + # what is the TP size in debug mode + debug_tensor_parallel_size: 4 + + + # trainer common configurations + trainer_common: + + # validation before training + val_before_train: False + val_pass_n: 4 + + # save and test frequency (in step) + save_freq: 20 + test_freq: 20 + + # total training epochs + total_epochs: 50 + + nnodes: 1 + n_gpus_per_node: 8 + + # logger selection + logger: swanlab + + # algorithm setting + algorithm: + adv_estimator: grpo + use_kl_in_reward: False + + # number of optimizer.step per big batch + mini_batch_num: 1 + + # verl offload configs + fsdp_config: + param_offload: True + optimizer_offload: True + + # learning rate + optim: + lr: 1e-6 + + # enable KL loss regularization + use_kl_loss: True + + # kl divergence loss coefficient + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + + # Ulysses specific configs + ulysses_sequence_parallel_size: 1 + + # base directory to save checkpoints + checkpoint_base_dir: ./saved_checkpoints + + # whether to save train/eval trajectories to JSON files + save_trajectory_as_json_file: False + + + # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature + enable_experimental_interchange_server: True + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 'auto' + num_fastapi_process: 2 # 1, 2 or 4 is fine + max_fastapi_threads: 128 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + + + task_runner: + # submit llm infer submit method + llm_infer_submit_method: "async" # options: "sync", "async" + + # how to wrap the user-defined workflow + wrapper_type: "asyncio-with-gc" + # - wrapper_type: "asyncio-with-gc": safe, with periodic garbage collection to prevent event loop leaks (recommended) + # - wrapper_type: "asyncio": fast, but may cause event loop leak in long run + # - wrapper_type: "multi-processing": safe, but resource consuming + + # when `wrapper_type` is `multi-processing`, the timeout for each task + wrapper_multiprocessing_timeout: 3600 # in seconds + + # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_test: False # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. diff --git a/ajet/default_config/trinity/README.md b/ajet/default_config/trinity/README.md new file mode 100644 index 00000000..0fb7dec2 --- /dev/null +++ b/ajet/default_config/trinity/README.md @@ -0,0 +1,63 @@ +# Trinity Configuration Guide 🛠️ + +## How to Modify Trinity Configuration in AgentJet + +1. 🎯 **Recommended Method**: In most cases, you do not need to directly adjust Trinity parameters. Simply refer to and modify the upper-level `ajet/default_config/ajet_default.yaml` configuration file, and AgentJet will **automatically** handle parameter mapping for you. + +2. ⚙️ **Special Cases**: Some Trinity tuning parameters are not yet mapped in AgentJet. You can refer to Trinity’s documentation and modify them in the following format: + +```yaml +trinity: + algorithm: + algorithm_type: multi_step_grpo +``` + +3. 🚫 **Never Edit**: + - Never edit `ajet/default_config/trinity/trinity_launcher.yaml` + - Never edit `ajet/default_config/trinity/trinity_default.yaml` + +## Configuration Mapping Modification 🔄 + +Some AgentJet configurations overlap with Trinity. +You can configure mappings via the `ajet/default_config/trinity/config_auto_convertion_trinity.jsonc` file. + +## Trinity Hyperparameter Quick Guide 📊 + +Trinity adopts a typical producer (explorer)-consumer (trainer) architecture: +- 🏭 **Producer**: Uses VLLM to generate samples +- 🧠 **Consumer**: Consumes samples to update the model +Both operate on different runtime schedules. + +### Explorer Core Parameters 🔍 + +- `buffer.batchsize`: The minimum unit for reading task data from the dataset. Each read increments the explorer step count by 1. +- `repeat_times`: The number of repetitions per task, also the group size (G) in GRPO. +- `engine_num`: Number of VLLM engines. +- `tensor_parallel_size`: Number of GPUs occupied by each VLLM engine. +- `engine_num * tensor_parallel_size`: Total number of GPUs used by the explorer. +- `eval_interval`: Evaluation interval (in explorer steps). + +### Trainer Core Parameters 🏋️ + +- `buffer.train_batch_size`: The minimum unit consumed from the explorer’s production queue. Each read triggers one optimization step. +- `trainer.save_interval`: Parameter save interval (in trainer steps). + +### Explorer-Trainer Coordination Parameters 🤝 + +- `sync_interval`: Synchronization interval. +- `sync_offset`: Synchronization offset. +- `sync_style`: Synchronization method. + +### Runtime Instance Analysis 📈 + +**Supply Side**: Explorer runs 89 steps × batch size (8) × repeat times (4) × tasks per round (≈1) = 2,848 samples. + +meanwhile + +**Demand Side**: Trainer runs 10 steps × training batch size (264) = 2,640 samples. + +### Training Memory Control 💾 + +Same as VERL, control training memory with the following parameters: +- `trainer.max_token_len_per_gpu` +- `ulysses_sequence_parallel_size` diff --git a/ajet/default_config/trinity/README_ZH.md b/ajet/default_config/trinity/README_ZH.md new file mode 100644 index 00000000..1f3a02f0 --- /dev/null +++ b/ajet/default_config/trinity/README_ZH.md @@ -0,0 +1,64 @@ +# Trinity 配置指南 🛠️ + +## 如何修改 Trinity 配置 + +1. 🎯 **推荐方式**:在大多数情况下,您无需直接调整 Trinity 参数,只需参考并修改上层的 `ajet/default_config/ajet_default.yaml` 配置文件即可, + AgentJet 会**自动**帮您完成参数映射。 + +2. ⚙️ **特殊情况**:部分 Trinity 调优参数目前尚未在 AgentJet 中建立映射,您可以参考 Trinity 的文档,然后通过以下形式进行修改: + +```yaml +trinity: + algorithm: + algorithm_type: multi_step_grpo +``` + +3. 🚫 **永远不要编辑**: + - 永远不要编辑 `ajet/default_config/trinity/trinity_launcher.yaml` + - 永远不要编辑 `ajet/default_config/trinity/trinity_default.yaml` + +## 配置映射修改 🔄 + +某些 AgentJet 配置与 Trinity 存在重叠, +可通过 `ajet/default_config/trinity/config_auto_convertion_trinity.jsonc` 文件进行映射配置 + +## Trinity 超参数简明指南 📊 + +Trinity 采用典型的生产者(探索器)-消费者(训练器)架构: +- 🏭 **生产者**:使用 VLLM 生成样本 +- 🧠 **消费者**:消耗样本更新模型 +两者具有不同的运行时序 + +### 探索器核心参数 🔍 + +- `buffer.batchsize`:从数据集读取任务数据的最小单位,每次读取视为探索器步数 +1 +- `repeat_times`:每个任务重复次数,也是 GRPO 中 G(分组)的大小 +- `engine_num`:VLLM 引擎数量 +- `tensor_parallel_size`:每个 VLLM 引擎占用的显卡数量 +- `engine_num * tensor_parallel_size`:探索器使用的总显卡数量 +- `eval_interval`:评估间隔(以探索器步数为单位) + +### 训练器核心参数 🏋️ + +- `buffer.train_batch_size`:从探索器生产队列中消费的最小单位,每次读取后执行一次优化步骤 +- `trainer.save_interval`:参数保存间隔(以训练器步数为单位) + +### 探索器-训练器协同参数 🤝 + +- `sync_interval`:同步间隔 +- `sync_offset`:同步偏移 +- `sync_style`:同步方式 + +### 运行实例分析 📈 + +**供给端**:探索器运行 89 步 × 批次大小(8) × 重复次数(4) × 每轮任务(≈1) = 2848 个样本 + +与此同时,在另一边 + +**消费端**:训练器运行 10 步 × 训练批次大小(264) = 2640 个样本 + +### 训练显存控制 💾 + +与 VERL 相同,通过以下参数控制训练显存: +- `trainer.max_token_len_per_gpu` +- `ulysses_sequence_parallel_size` diff --git a/ajet/default_config/trinity/config_auto_convertion_trinity.jsonc b/ajet/default_config/trinity/config_auto_convertion_trinity.jsonc new file mode 100644 index 00000000..02b8f812 --- /dev/null +++ b/ajet/default_config/trinity/config_auto_convertion_trinity.jsonc @@ -0,0 +1,26 @@ +{ + "ajet.trainer_common.nnodes": "cluster.node_num", + "ajet.trainer_common.n_gpus_per_node": "cluster.gpu_per_node", + "ajet.trainer_common.total_epochs": "buffer.total_epochs", + "ajet.trainer_common.test_freq": "explorer.eval_interval", + "ajet.trainer_common.save_freq": "trainer.save_interval", + "ajet.trainer_common.ulysses_sequence_parallel_size": "trainer.ulysses_sequence_parallel_size", + "ajet.trainer_common.optim.lr": "algorithm.optimizer.lr", + "ajet.trainer_common.logger": "monitor.monitor_type", + "ajet.trainer_common.checkpoint_base_dir": "checkpoint_root_dir", + "ajet.rollout.n_vllm_engine": "explorer.rollout_model.engine_num", + "ajet.model.path": "model.model_path", + "ajet.project_name": "project", + "ajet.experiment_name": "name", + "ajet.data.max_response_length": "model.max_response_tokens", + "ajet.rollout.num_repeat": "algorithm.repeat_times", + "ajet.rollout.tensor_model_parallel_size": "explorer.rollout_model.tensor_parallel_size", + "ajet.rollout.max_model_len": [ + "model.max_model_len", + "trainer.max_token_len_per_gpu" + ], + + "ajet.data.train_batch_size": "buffer.batch_size", + "(ajet.data.train_batch_size * ajet.rollout.num_repeat * ajet.rollout.multi_turn.expected_steps)": "buffer.train_batch_size", + "(min(ajet.rollout.max_env_worker, 128) // ajet.rollout.n_vllm_engine)": "explorer.runner_per_model" +} diff --git a/ajet/default_config/trinity/trinity_default.yaml b/ajet/default_config/trinity/trinity_default.yaml new file mode 100644 index 00000000..53be9a2a --- /dev/null +++ b/ajet/default_config/trinity/trinity_default.yaml @@ -0,0 +1,78 @@ +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD + +trinity: + algorithm: + algorithm_type: multi_step_grpo + policy_loss_fn_args: + fallback_to_policy_gradient: true + clip_range: 0.2 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + enable_sequence_masking: false + delta_sequence_masking: 0.1 + + buffer: + # Explorer-side mini-batch for reading tasks from the dataset (producer input size) + batch_size: 32 + explorer_input: + eval_tasksets: [] + taskset: + default_workflow_type: ajet_workflow + format: + prompt_key: question + response_key: answer + name: "" + path: http://localhost:8080 + rollout_args: + temperature: 1.0 + split: train + storage_type: ajet + subset_name: "" + total_epochs: 1000 + # Trainer-side mini-batch consumed from the explorer queue per optimization step (consumer input size) + train_batch_size: 36 + trainer_input: + experience_buffer: + max_read_timeout: 18000 + name: "ajet_experience_buffer" + storage_type: queue + replay_buffer: + enable: false + priority_fn: linear_decay + reuse_cooldown_time: null + priority_fn_args: + decay: 2.0 + checkpoint_root_dir: ./trinity_checkpoints + # Explorer = producer (typically VLLM), generates samples + explorer: + max_repeat_times_per_runner: 1 + max_timeout: 3600 + rollout_model: + dtype: bfloat16 + enable_auto_tool_choice: true + enable_history: true + enable_openai_api: true + enable_prefix_caching: false + enable_thinking: false + enforce_eager: false + seed: 42 + tool_call_parser: hermes + # runner_state_report_interval: 30 + monitor: + monitor_type: swanlab + name: dummy_name + project: dummy_project + synchronizer: + sync_interval: 1 + sync_method: nccl + sync_style: dynamic_by_explorer + sync_timeout: 7200 + # Trainer = consumer, updates model parameters using explorer outputs + trainer: + grad_clip: 1.0 + use_dynamic_bsz: true + cluster: + gpu_per_node: 8 + node_num: 1 diff --git a/ajet/default_config/trinity/trinity_launch.yaml b/ajet/default_config/trinity/trinity_launch.yaml new file mode 100644 index 00000000..8b6f679d --- /dev/null +++ b/ajet/default_config/trinity/trinity_launch.yaml @@ -0,0 +1,6 @@ +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD + +algorithm: + algorithm_type: multi_step_grpo diff --git a/ajet/default_config/verl/config_auto_convertion_verl.jsonc b/ajet/default_config/verl/config_auto_convertion_verl.jsonc new file mode 100644 index 00000000..378fd112 --- /dev/null +++ b/ajet/default_config/verl/config_auto_convertion_verl.jsonc @@ -0,0 +1,39 @@ +{ + "ajet.trainer_common.total_epochs": "trainer.total_epochs", + + "ajet.trainer_common.val_before_train": "trainer.val_before_train", + "ajet.trainer_common.n_gpus_per_node": "trainer.n_gpus_per_node", + "ajet.trainer_common.nnodes": "trainer.nnodes", + "ajet.trainer_common.logger": "trainer.logger", + "ajet.trainer_common.checkpoint_base_dir": "trainer.checkpoint_base_dir", + "ajet.trainer_common.algorithm.adv_estimator": "algorithm.adv_estimator", + "ajet.trainer_common.algorithm.use_kl_in_reward": "algorithm.use_kl_in_reward", + "ajet.trainer_common.mini_batch_num": "actor_rollout_ref.actor.override_ppo_mini_batch_num", + "ajet.trainer_common.fsdp_config": "actor_rollout_ref.actor.fsdp_config", + "ajet.trainer_common.optim": "actor_rollout_ref.actor.optim", + "ajet.trainer_common.use_kl_loss": "actor_rollout_ref.actor.use_kl_loss", + "ajet.trainer_common.kl_loss_coef": "actor_rollout_ref.actor.kl_loss_coef", + "ajet.trainer_common.kl_loss_type": "actor_rollout_ref.actor.kl_loss_type", + "ajet.trainer_common.ulysses_sequence_parallel_size": "actor_rollout_ref.actor.ulysses_sequence_parallel_size", + + "ajet.trainer_common.save_freq": "trainer.save_freq", + "ajet.trainer_common.test_freq": "trainer.test_freq", + + "ajet.data.max_prompt_length": "data.max_prompt_length", + "ajet.data.max_response_length": "data.max_response_length", + "ajet.data.train_batch_size": "data.train_batch_size", + + "ajet.rollout.max_model_len": [ + "actor_rollout_ref.rollout.max_model_len", + "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu", + "actor_rollout_ref.actor.ppo_max_token_len_per_gpu", + "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu" + ], + + "ajet.rollout.multi_turn": "actor_rollout_ref.rollout.multi_turn", + "ajet.rollout.val_kwargs": "actor_rollout_ref.rollout.val_kwargs", + + "ajet.model.path": "actor_rollout_ref.model.path", + "ajet.project_name": "trainer.project_name", + "ajet.experiment_name": "trainer.experiment_name" +} diff --git a/ajet/default_config/verl/verl_default.yaml b/ajet/default_config/verl/verl_default.yaml new file mode 100644 index 00000000..51406396 --- /dev/null +++ b/ajet/default_config/verl/verl_default.yaml @@ -0,0 +1,433 @@ +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD +# DO NOT EDIT: THIS FILE IS READ ONLY and ALWAYS FIXED, EDIT `ajet/default_config/ajet_default.yaml` INSTEAD + +ajet: + rollout: + step_skip_action: 0 + submit_oversample_multiplier: 1.5 + enable_oversample: False + +actor_rollout_ref: + actor: + _target_: verl.workers.config.FSDPActorConfig + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + async_save: false + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra + clip_ratio: 0.2 + clip_ratio_c: 3.0 + clip_ratio_high: 0.2 + clip_ratio_low: 0.2 + entropy_checkpointing: false + entropy_coeff: 0 + entropy_from_logits_with_chunking: false + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + fsdp_size: -1 + offload_policy: false + optimizer_offload: true + param_offload: true + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + grad_clip: 1.0 + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + loss_agg_mode: seq-mean-token-mean + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + lr: 1.0e-06 + lr_warmup_steps: -1 + lr_warmup_steps_ratio: 0.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + total_training_steps: -1 + warmup_style: constant + weight_decay: 0.01 + override_ppo_mini_batch_num: 1 + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + clip_cov_lb: 1.0 + clip_cov_ratio: 0.0002 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + loss_mode: vanilla + ppo_kl_coef: 0.1 + ppo_epochs: 1 + ppo_max_token_len_per_gpu: 13000 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: 1 + ppo_mini_batch_size: 16 + shuffle: false + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true + use_fused_kernels: false + use_kl_loss: true + use_remove_padding: true + use_torch_compile: true + hybrid_engine: true + model: + custom_chat_template: null + enable_activation_offload: false + enable_gradient_checkpointing: true + exclude_modules: null + external_lib: null + fused_kernel_options: + impl_backend: torch + lora_alpha: 16 + lora_rank: 0 + override_config: {} + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + target_modules: all-linear + trust_remote_code: false + use_fused_kernels: false + use_liger: false + use_remove_padding: true + use_shm: false + nccl_timeout: 600 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + all_ranks: false + discrete: false + ranks: [] + ref: + entropy_checkpointing: false + entropy_from_logits_with_chunking: false + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + param_offload: true + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + log_prob_max_token_len_per_gpu: 13000 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: true + model: null + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true + use_torch_compile: true + rollout: + agent: + agent_loop_config_path: null + custom_async_server: + name: null + path: null + num_workers: 8 + calculate_log_probs: false + cudagraph_capture_sizes: null + custom_dataflow_cls: + name: '' + path: '' + disable_log_stats: true + do_sample: true + dtype: bfloat16 + enable_chunked_prefill: true + enforce_eager: true + engine_kwargs: + sglang: + attention_backend: null + vllm: + disable_mm_preprocessor_cache: false + swap_space: null + free_cache_engine: true + gamma: 1.0 + gpu_memory_utilization: 0.9 + ignore_eos: false + layered_summon: false + load_format: dummy_dtensor + log_prob_max_token_len_per_gpu: 13000 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: true + max_env_worker: 64 + max_model_len: 13000 + max_num_batched_tokens: 8192 + max_num_seqs: 10 + mode: async + multi_stage_wake_up: false + multi_turn: + enable: true + format: hermes + interaction_config_path: null + max_assistant_turns: null + max_parallel_calls: 1 + max_sample_per_task: 4 + expected_steps: 1 + max_steps: 30 + max_tool_response_length: 256 + max_user_turns: null + tokenization_sanity_check_mode: strict + tool_config_path: null + tool_response_truncate_side: middle + use_inference_chat_template: false + n: 1 + name: vllm + ppo_micro_batch_size_per_gpu: 1 + prompt_length: 3000 + response_length: 10000 + skip_dump_dir: /tmp/rollout_dump + skip_rollout: false + temperature: 0.9 + tensor_model_parallel_size: 1 + top_k: -1 + top_p: 1.0 + trace: + backend: null + token2text: false + update_weights_bucket_megabytes: 512 + val_kwargs: + do_sample: false + n: 1 + num_repeat: 1 + temperature: 0.0 + top_k: -1 + top_p: 1.0 + + +algorithm: + _target_: verl.trainer.config.AlgoConfig + adv_estimator: grpo + gamma: 1.0 + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + horizon: 10000 + kl_coef: 0.001 + target_kl: 0.1 + type: fixed + kl_penalty: kl + lam: 1.0 + norm_adv_by_std_in_grpo: true + pf_ppo: + reweight_method: pow + weight_pow: 2.0 + use_kl_in_reward: false + use_pf_ppo: false + + +critic: + _target_: verl.workers.config.FSDPCriticConfig + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + async_save: false + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra + cliprange_value: 0.5 + enable: false + forward_max_token_len_per_gpu: 32768 + forward_micro_batch_size: null + forward_micro_batch_size_per_gpu: null + grad_clip: 1.0 + loss_agg_mode: seq-mean-token-mean + model: + _target_: verl.workers.config.FSDPCriticModelCfg + enable_activation_offload: false + enable_gradient_checkpointing: true + external_lib: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + fsdp_size: -1 + offload_policy: false + optimizer_offload: false + param_offload: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + lora_alpha: 16 + lora_rank: 0 + override_config: {} + path: ~/models/deepseek-llm-7b-chat + target_modules: all-linear + tokenizer_path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + trust_remote_code: false + use_remove_padding: false + use_shm: false + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + lr: 1.0e-05 + lr_warmup_steps: -1 + lr_warmup_steps_ratio: 0.0 + min_lr_ratio: null + total_training_steps: -1 + warmup_style: constant + weight_decay: 0.01 + ppo_epochs: 1 + ppo_max_token_len_per_gpu: 32768 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + ppo_mini_batch_size: 16 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + all_ranks: false + discrete: false + ranks: [] + rollout_n: 1 + shuffle: false + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true + + +custom_reward_function: + name: compute_score + path: null + + +data: + custom_cls: + name: null + path: null + datagen: + name: null + path: null + dataloader_num_workers: 8 + fast_eval: true + filter_overlong_prompts: true + filter_overlong_prompts_workers: 1 + image_key: images + max_prompt_length: 3000 + max_response_length: 10000 + prompt_key: prompt + return_full_prompt: false + return_multi_modal_inputs: true + return_raw_chat: true + return_raw_input_ids: false + reward_fn_key: data_source + sampler: + class_name: null + class_path: null + shuffle: true + tokenizer: null + train_batch_size: 264 + train_files: ~/data/rlhf/gsm8k/train.parquet + truncation: error + trust_remote_code: false + use_shm: false + val_batch_size: 100000000000 + val_files: ~/data/rlhf/gsm8k/test.parquet + validation_shuffle: false + video_key: videos + seed: 42 + + +ray_init: + num_cpus: null + timeline_json_file: null + + +reward_model: + enable: false + forward_max_token_len_per_gpu: 32768 + launch_reward_fn_async: false + max_length: null + micro_batch_size: null + micro_batch_size_per_gpu: null + model: + external_lib: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + fsdp_size: -1 + param_offload: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + input_tokenizer: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + trust_remote_code: false + use_fused_kernels: false + use_remove_padding: false + use_shm: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + all_ranks: false + discrete: false + ranks: [] + reward_manager: naive + sandbox_fusion: + max_concurrent: 64 + memory_limit_mb: 1024 + url: null + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true + + +trainer: + balance_batch: true + controller_nsight_options: + cuda-graph-trace: graph + cuda-memory-usage: 'true' + trace: cuda,nvtx,cublas,ucx + critic_warmup: 0 + default_hdfs_dir: null + checkpoint_base_dir: ./saved_checkpoints + default_local_dir: ${trainer.checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name} + del_local_ckpt_after_load: false + device: cuda + esi_redundant_time: 0 + experiment_name: read_yaml_name + hfmodelpath: '' + log_val_generations: 0 + logger: + - console + - swanlab + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + n_gpus_per_node: 8 + nnodes: 1 + npu_profile: + options: + analysis: true + level: level1 + record_shapes: false + roles: + - all + save_path: ./profiler_data + with_cpu: true + with_memory: false + with_module: false + with_npu: true + with_stack: false + profile_continuous_steps: false + profile_steps: null + project_name: project_name_placeholder + ray_wait_register_center_timeout: 300 + resume_from_path: null + resume_mode: auto + rollout_data_dir: null + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + total_training_steps: null + use_legacy_worker_impl: auto + val_before_train: false + val_only: false + val_pass_n: 4 + validation_data_dir: null + worker_nsight_options: + capture-range: cudaProfilerApi + capture-range-end: null + cuda-graph-trace: graph + cuda-memory-usage: 'true' + kill: none + trace: cuda,nvtx,cublas,ucx diff --git a/ajet/launcher.py b/ajet/launcher.py new file mode 100644 index 00000000..963e1338 --- /dev/null +++ b/ajet/launcher.py @@ -0,0 +1,319 @@ +import argparse +import os +import subprocess + +from dotenv import load_dotenv +from loguru import logger + +from ajet.utils.cleaner import fast_kill_by_keyword_bash +from ajet.utils.config_utils import prepare_experiment_config +from ajet.utils.launch_utils import ( + execute_training_process, + launch_logview, + set_loguru_default_color, + start_ray_service, +) +from ajet.utils.pty import pty_launch + +set_loguru_default_color() +load_dotenv(override=False) + + +def parse_args(): + parser = argparse.ArgumentParser(description="AgentJet Launcher") + parser.add_argument( + "--backbone", + type=str, + default="verl", + required=False, + help="verl or trinity or debug", + ) + parser.add_argument( + "--conf", + type=str, + default="", + required=False, + help="Path to configuration file", + ) + parser.add_argument( + "--exp-dir", + type=str, + default="saved_experiments", + required=False, + help="Path to experiment directory", + ) + parser.add_argument( + "--debug", + "--db", + type=str, + default="", + required=False, + help="Path to configuration file", + ) + + parser.add_argument("--with-ray", action="store_true", default=False, help="Launch ray") + parser.add_argument("--with-ray-cluster", action="store_true", default=False, help="Launch ray") + parser.add_argument( + "--with-appworld", + action="store_true", + default=False, + help="Launch appworld", + ) + parser.add_argument( + "--with-finworld", + action="store_true", + default=False, + help="Launch finworld", + ) + parser.add_argument( + "--with-webshop", + action="store_true", + default=False, + help="Launch webshop", + ) + parser.add_argument("--with-bfcl", action="store_true", default=False, help="Launch bfcl") + parser.add_argument( + "--with-logview", + action="store_true", + default=False, + help="Launch logview", + ) + parser.add_argument( + "--with-crafters", + action="store_true", + default=False, + help="Launch Crafters Env Simulation", + ) + parser.add_argument("--reboot", action="store_true", default=False, help="reboot flag") + parser.add_argument("--skip-check-avail-gpu", action="store_true", default=False, help="Skip GPU availability check") + parser.add_argument( + "--kill", + type=str, + default="", + required=False, + help="list of keywords for killing processes", + ) + parser.add_argument( + "--autokill", + action="store_true", + default=False, + help="Kill system processes (ray + vllm + python) that may block the current experiment", + ) + return parser.parse_args() + + +def check_debugpy_version(): + try: + import debugpy + except ImportError: + raise RuntimeError("Module 'debugpy>=1.8.0' cannot be loaded. " "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. " "Install this module using 'pip install debugpy>=1.8.0'") + version = getattr(debugpy, "__version__", "0.0.0") + from packaging import version as packaging_version + + if packaging_version.parse(version) < packaging_version.parse("1.8.0"): + raise RuntimeError(f"debugpy version {version} is too old. " "Ray Debugpy Debugger requires 'debugpy>=1.8.0'. " "Upgrade using 'pip install debugpy>=1.8.0'") + logger.info(f"✓ debugpy version {version} meets requirement (>=1.8.0)") + + +def check_avail_gpu(min_free_ratio: float = 0.95): + """ + Ensure there is at least one GPU and all GPUs have >= min_free_ratio free memory. + + Uses `nvidia-smi` to query total and used memory for each GPU. + Raises RuntimeError if no GPU is found or any GPU violates the free ratio threshold. + """ + try: + # Query GPU memory via nvidia-smi; output in MiB + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=name,memory.total,memory.used", + "--format=csv,noheader,nounits", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + except FileNotFoundError: + raise RuntimeError("nvidia-smi not found. NVIDIA drivers/GPU may be unavailable.") + + if result.returncode != 0: + raise RuntimeError(f"Failed to query GPUs via nvidia-smi: {result.stderr.strip()}") + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + if not lines: + raise RuntimeError("No GPUs detected by nvidia-smi.") + + violations = [] + for idx, line in enumerate(lines): + # Expected format: ", , " + parts = [p.strip() for p in line.split(",")] + if len(parts) < 3: + violations.append((idx, "parse-error", line)) + continue + name, total_str, used_str = parts[0], parts[1], parts[2] + try: + total = float(total_str) + used = float(used_str) + except ValueError: + violations.append((idx, "parse-error", line)) + continue + free = max(total - used, 0.0) + free_ratio = free / total if total > 0 else 0.0 + logger.info(f"GPU {idx} ({name}): total={total:.0f} MiB, used={used:.0f} MiB, free_ratio={free_ratio:.3f}") + if free_ratio < min_free_ratio: + violations.append((idx, name, f"free_ratio={free_ratio:.3f} < {min_free_ratio:.3f}")) + + if violations: + details = "; ".join([f"GPU {i} ({n}): {msg}" for i, n, msg in violations]) + raise RuntimeError("GPU memory check failed: all GPUs must have >= " f"{int(min_free_ratio*100)}% free. Violations: {details}") + logger.info(f"✓ GPU check passed: {len(lines)} GPUs, all >= {int(min_free_ratio*100)}% free memory") + + +def get_backbone_target(backbone): + """ + Determine the appropriate backbone target module based on the backbone name. + + Args: + backbone (str): The backbone name (e.g., "verl", "debug", "trinity") + + Returns: + str: The full module path for the specified backbone + """ + backbone_target = "ajet.backbone.main_verl" # Default to trinity + if backbone == "verl": + backbone_target = "ajet.backbone.main_verl" + if backbone == "debug": + backbone_target = "ajet.backbone.main_vllm" + if backbone == "trinity": + backbone_target = "ajet.backbone.main_trinity" + return backbone_target + + +def setup_environment_vars(args, exp_config, main_yaml_fp): + """ + Configure environment variables based on command line arguments. + + Args: + args: Command line arguments + exp_config: Experiment configuration dictionary + main_yaml_fp: Path to main YAML configuration file + + Returns: + dict: Configured environment variables dictionary + """ + env = os.environ.copy() + if args.debug: + env["RAY_DEBUG_POST_MORTEM"] = "1" + env["DEBUG_TAGS"] = args.debug + env["RAY_record_task_actor_creation_sites"] = "true" + # assert exp_config["ajet"]["rollout"]["max_env_worker"] <= 4, "parallel worker too many for debugging mode" # type: ignore + if exp_config["ajet"]["rollout"]["max_env_worker"] > 1: # type: ignore + exp_config["ajet"]["rollout"]["max_env_worker"] = 1 + logger.warning("For debugging mode, max_env_worker is set to 1 to facilitate debugging.") + logger.warning("Debug mode is ON") + else: + logger.warning("Debug mode is OFF") + # if args.conf: + # assert exp_config["ajet"]["rollout"]["max_env_worker"] > 4, "parallel worker too few" # type: ignore + if args.backbone == "trinity": + env["AJET_CONFIG_REDIRECT"] = main_yaml_fp # type: ignore + if args.backbone == "debug": + env["AJET_DEBUG"] = "1" # type: ignore + return env, exp_config + + +def check_model_file_exists(exp_config): + model_path = exp_config["ajet"]["model"]["path"] + # if model_path has more than 2 '/', we consider it as a dir path + if model_path.count("/") > 2: + assert os.path.exists(model_path), f"Model path {model_path} does not exist. Please check your configuration." + + +def main(): + args = parse_args() + + # Enforce GPU availability and free memory threshold before proceeding + if not args.skip_check_avail_gpu: + if (args.backbone != "debug") and (not args.kill) and (not args.autokill): + check_avail_gpu(min_free_ratio=0.95) + + if args.autokill: + args.kill = "ray|vllm|VLLM|python" + + # Handle kill-keywords argument if provided + if args.kill: + logger.info(f"Killing processes matching keywords: {args.kill}") + for keyword in args.kill.split("|"): + logger.info(f"Killing processes matching keyword: {keyword}") + killed_pids = fast_kill_by_keyword_bash(keyword) + if killed_pids: + logger.success(f"Successfully killed processes with PIDs: {killed_pids}") + else: + logger.warning(f"No processes found matching keyword: {keyword}") + if not args.conf: + return + + # Initialize variables with default values to avoid "possibly unbound" errors + main_yaml_fp = None + exe_exp_base = None + exp_name = None + + # switch backbone target + backbone_target = get_backbone_target(args.backbone) + + exp_config = None + exp_dir = args.exp_dir or "saved_experiments" + if args.conf: + yaml_path = args.conf + ( + main_yaml_fp, + exe_exp_base, + exp_name, + exp_config, + ) = prepare_experiment_config(yaml_path, exp_dir, args.backbone) + + env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) + if args.with_ray: + assert not args.with_ray_cluster, "Cannot use both --with-ray and --with-ray-cluster simultaneously." + start_ray_service(args, env) + + if args.with_appworld: + pty_launch("appworld") + + if args.with_finworld: + pty_launch("finworld") + + if args.with_crafters: + pty_launch("crafters") + + if args.with_webshop: + pty_launch("webshop") + + if args.with_bfcl: + pty_launch("bfcl") + + if args.with_logview: + launch_logview(exp_name) + + if args.with_ray_cluster: + assert not args.with_ray, "Cannot use both --with-ray and --with-ray-cluster simultaneously." + start_ray_service(args, env, cluster=True) + + if args.conf and main_yaml_fp and exe_exp_base and exp_config: + check_model_file_exists(exp_config) + execute_training_process( + args, + backbone_target, + main_yaml_fp, + exe_exp_base, + main_yaml_fp, + env, + exp_config, + ) + + +if __name__ == "__main__": + check_debugpy_version() + main() diff --git a/astune/utils/__init__.py b/ajet/schema/__init__.py similarity index 100% rename from astune/utils/__init__.py rename to ajet/schema/__init__.py diff --git a/ajet/schema/convertion.py b/ajet/schema/convertion.py new file mode 100644 index 00000000..076f629f --- /dev/null +++ b/ajet/schema/convertion.py @@ -0,0 +1,105 @@ +import time +from openai.types.chat.chat_completion import ChatCompletion, Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from agentscope.model import ChatResponse as AgentScopeChatResponse +from openai.types.completion_usage import CompletionUsage +from typing import Any, Callable, Dict, List, Literal, Type, Union +from agentscope.message import TextBlock, ToolUseBlock +from agentscope._utils._common import _json_loads_with_repair +from pydantic import BaseModel +from agentscope.model import ChatResponse + + +def convert_llm_proxy_response_to_oai_response(llm_proxy_response): + # Create the chat completion message + message = ChatCompletionMessage( + role=llm_proxy_response.get("role", "assistant"), + content=llm_proxy_response.get("content", ""), + tool_calls=llm_proxy_response.get("tool_calls", []), + ) + + # Create a choice object + choice = Choice( + index=0, + message=message, + finish_reason="stop", + ) + + # Calculate token usage if tokens are available + usage = None + if "tokens" in llm_proxy_response and llm_proxy_response["tokens"]: + completion_tokens = len(llm_proxy_response["tokens"]) + usage = CompletionUsage( + prompt_tokens=0, # Not available in llm_proxy_response + completion_tokens=completion_tokens, + total_tokens=completion_tokens, + ) + + return ChatCompletion( + id=llm_proxy_response.get("request_id", "chatcmpl-default"), + choices=[choice], + created=int(time.time()), + model="unknown", # Model name not provided in llm_proxy_response + object="chat.completion", + usage=usage, + ) + + +# modified from AgentScope's DashScopeChatModule +def convert_llm_proxy_response_to_agentscope_response( + message, + structured_model: Type[BaseModel] | None = None, +) -> AgentScopeChatResponse: # type: ignore + content_blocks: List[TextBlock | ToolUseBlock] = [] + content = message.get("content") + metadata: dict | None = None + + if content not in [ + None, + "", + [], + ]: + if isinstance(content, list): + for item in content: + if isinstance(item, dict) and "text" in item: + content_blocks.append( + TextBlock( + type="text", + text=item["text"], + ), + ) + else: + content_blocks.append( + TextBlock( + type="text", + text=content, + ), + ) + + if message.get("tool_calls"): + for tool_call in message["tool_calls"]: + input_ = _json_loads_with_repair( + tool_call["function"].get( + "arguments", + "{}", + ) + or "{}", + ) + content_blocks.append( + ToolUseBlock( + type="tool_use", + name=tool_call["function"]["name"], + input=input_, # type: ignore + id=tool_call["id"], + ), + ) + + if structured_model: + metadata = input_ # type: ignore + + parsed_response = AgentScopeChatResponse( + content=content_blocks, + metadata=metadata, + ) + + return parsed_response diff --git a/ajet/schema/document.py b/ajet/schema/document.py new file mode 100644 index 00000000..ac2cb1f8 --- /dev/null +++ b/ajet/schema/document.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel, Field + + +class Document(BaseModel): + doc_id: str = Field(default="") + content: str = Field(default="") + metadata: dict = Field(default_factory=dict) diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py new file mode 100644 index 00000000..86f51d23 --- /dev/null +++ b/ajet/schema/extended_msg.py @@ -0,0 +1,335 @@ +import uuid +from typing import List + +from loguru import logger +from transformers.tokenization_utils import PreTrainedTokenizer + +from ajet.utils.tokenizer import ajet_apply_chat_template + +# import numpy as np +# INVALID_LOG_PROB_VALUE = np.inf # when debuging, set to np.inf, if anything goes wrong, we can sense that immediately +INVALID_LOG_PROB_VALUE = 0 # normally, set to 0 is ok +NEED_TRAIN_AUTHORS = ["llm"] +NON_TRAIN_AUTHORS = [ + "env", + "initialization", + "user", + "memory", + "llm(do_not_train)", +] +DUMMY_MSG = [{"role": "assistant", "content": "dummy text"}] + + +def find_sublist_indices(large_list, small_list, reverse=False): + small_len = len(small_list) + if reverse: + for i in reversed(range(len(large_list) - small_len + 1)): + if large_list[i : i + small_len] == small_list: + return i + for i in range(len(large_list) - small_len + 1): + if large_list[i : i + small_len] == small_list: + return i + return -1 + + +def blackout_specific_token_ids_first_encounter(mask, arr, token_ids): + index = find_sublist_indices(arr, token_ids, reverse=False) + if index >= 0: + for i in range(index, index + len(token_ids)): + mask[i] = 0 + return mask + + +def blackout_everything_after_eos_but_keep_eos(mask, token_arr, eos_token_id): + eos_position = token_arr.index(eos_token_id) if eos_token_id in token_arr else -1 + if eos_position != -1: + for i in range(eos_position + 1, len(mask)): + mask[i] = 0 + return mask + + +def blackout_everything_after_eos_including_eos(mask, token_arr, eos_token_id): + eos_position = token_arr.index(eos_token_id) if eos_token_id in token_arr else -1 + if eos_position != -1: + for i in range(eos_position, len(mask)): + mask[i] = 0 + return mask + + +class ExtendedMessage: + def __init__( + self, + author, + role="assistant", + content="", + token_arr=[], + token_begin_index=-1, + token_end_index=-1, + clip=False, + clip_token_limit=8192, + tokenizer: PreTrainedTokenizer = None, # type: ignore + token_generator="manual", + build_from_uuid="", + tools=[], + tool_calls=[], + tool_call_id="", + token_logprob_arr=[], + name="", # preserved field, not used currently + first_message=False, + ): + self.author = author + self.role = role + self.content = content + self.token_arr = token_arr + self.token_logprob_arr = token_logprob_arr + self.token_begin_index = token_begin_index + self.token_end_index = token_end_index + self.invalid_log_prob_value = INVALID_LOG_PROB_VALUE + self._content_for_future = "" + self._info = "" + self.clip = clip + self.tools = tools + self.tool_calls = tool_calls + self.tool_call_id = tool_call_id + self.name = name # preserved field, not used currently + if not isinstance(self.tool_calls, list): + # agent scope sometimes gives weird type for tool_calls, which is against OpenAI schema + self.tool_calls = list(self.tool_calls) + self.uuid = uuid.uuid4().hex + self.build_from_uuid = build_from_uuid + self.first_message = first_message + self.manual_loss_mask_override = [] + self.lack_normal_eos = False + + if not clip: + self.generate_content_for_future(tokenizer=None, clip=False) + else: + self.generate_content_for_future( + tokenizer=tokenizer, + clip=True, + clip_token_limit=clip_token_limit, + ) + self.eos_token_id = tokenizer.eos_token_id + + if token_generator == "auto": + self.token_arr = self.auto_tokenize( + tokenizer=tokenizer, + tools=tools, + ) + + def auto_tokenize(self, tokenizer, tools): + if (not self.first_message) and (self.role == "system"): + raise ValueError("The system message is usually the first message, check program bugs.") + elif (self.first_message) and (self.role != "system"): + raise ValueError("The first message is supposed to be the system message, check program bugs, or remove this warning.") + if not self.first_message: + self.token_arr = self.auto_tokenize_non_first_message(tokenizer=tokenizer, tools=tools) + else: + auto_tokenize_target = { + "role": self.role, + "content": self.content_for_future, + } + if self.tool_calls: + auto_tokenize_target.update({"tool_calls": self.tool_calls}) + self.token_arr = ajet_apply_chat_template( + tokenizer=tokenizer, + conversation=[auto_tokenize_target], + tokenize=True, + tools=tools, + ) + return self.token_arr + + def auto_tokenize_non_first_message(self, tokenizer, tools): + try: + # completion_token_arr will contain generation_prompt header + auto_tokenize_target = { + "role": self.role, + "content": self.content_for_future, + } + if self.tool_calls: + auto_tokenize_target.update({"tool_calls": self.tool_calls}) + if self.tool_call_id: + auto_tokenize_target.update({"tool_call_id": self.tool_call_id}) + text_frag_to = ajet_apply_chat_template( + tokenizer=tokenizer, + conversation=DUMMY_MSG + [auto_tokenize_target], + tokenize=False, + tools=tools, + ) + except Exception as e: + raise ValueError(f"Cannot tokenize {self.role} --- {self.content_for_future}, \n\n Error: {e}") + self.token_arr, _ = self.get_inc_simple( + text_frag_from=ajet_apply_chat_template( + tokenizer=tokenizer, + conversation=DUMMY_MSG, + tokenize=False, + tools=tools, + ), + text_frag_to=text_frag_to, + tokenizer=tokenizer, + ) + return self.token_arr + + @property + def content_for_future(self): + if self._content_for_future == "": + if not self.tool_calls: + logger.exception("content_for_future is not set, or previous llm output is empty!") + self._content_for_future + return self._content_for_future + + @property + def need_training(self): + assert (self.author in NEED_TRAIN_AUTHORS) or (self.author in NON_TRAIN_AUTHORS) or (self.author.endswith("(discard)")), f"author {self.author} is not identified" + return self.author in NEED_TRAIN_AUTHORS + + def generate_content_for_future(self, tokenizer, clip, clip_token_limit=-1): + _content: str = self.content + if clip: + assert clip_token_limit > 0, "clip_token_limit must be set when clip is True" + n_token = len(tokenizer(_content, return_tensors="pt", padding=False)["input_ids"][0]) + if n_token > clip_token_limit: + # 8000 > 4000 + n_char = len(_content) # 10,000 + eps = 100 # token + preserve_percent = (clip_token_limit - eps) / n_token # 3900 / 8000 + n_char_to_preserve = int(n_char * preserve_percent) + _content = _content[:n_char_to_preserve] + "... truncate ..." + self._content_for_future = _content + + def get_loss_mask(self, blackout_token_combo): + if self.need_training: + # keep eos, but blackout everything after eos + msg_token_mask = [1] * len(self.token_arr) + msg_token_mask = blackout_specific_token_ids_first_encounter(msg_token_mask, self.token_arr, blackout_token_combo) + # in normal case, we will blackout everything after the EOS token + # but EOS still participates in the loss calculation + msg_token_mask = blackout_everything_after_eos_but_keep_eos( + mask=msg_token_mask, + token_arr=self.token_arr, + eos_token_id=self.eos_token_id, + ) + # however, if the message does not have eos (e.g., finish_reason: length), we will blackout everything after the EOS token + # including the EOS token + if self.lack_normal_eos: + msg_token_mask = blackout_everything_after_eos_including_eos( + mask=msg_token_mask, + token_arr=self.token_arr, + eos_token_id=self.eos_token_id, + ) + if self.manual_loss_mask_override: + # assert two list is identical + assert len(self.manual_loss_mask_override) == len(msg_token_mask) + assert all(a == b for a, b in zip(self.manual_loss_mask_override, msg_token_mask)) + + return msg_token_mask + else: + msg_token_mask = [0] * len(self.token_arr) + return msg_token_mask + + def get_inc_simple(self, text_frag_from, text_frag_to, tokenizer): + """ + Get the incremental token array from text_frag_from to text_frag_to. + """ + tokenizer_output = tokenizer(text_frag_from, return_tensors="pt", padding=False) + tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() + token_ids_acc = tokenizer_input_ids + + tokenizer_output = tokenizer(text_frag_to, return_tensors="pt", padding=False) + input_ids = tokenizer_output["input_ids"][0].tolist() + # get the new tokens added in this step + input_id_increment = input_ids[len(token_ids_acc) :] + FN_DEBUG = False + if FN_DEBUG: + overlap_length = 0 + for i in range(len(token_ids_acc)): + if i < len(token_ids_acc) and input_ids[i] == token_ids_acc[i]: + overlap_length += 1 + else: + break + msg = f"previous token length: {len(token_ids_acc)}, overlap token length: {(overlap_length)}, increment token length: {len(input_id_increment)}" + else: + msg = "" + return input_id_increment, msg + + @staticmethod + def check_and_merge_chained_tool_response(ext_msg_array: List["ExtendedMessage"], tokenizer: PreTrainedTokenizer) -> List["ExtendedMessage"]: + """ + Inside a list of ExtendedMessage, + Find consecutive ext msg with role=="tool", then merge them into one ExtendedMessage + + Jinja2 template logic for reference: + + {%- elif message.role == \"tool\" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\ + \ + ' }} + {{- message.content }} + {{- '\ + ' }} + {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %} + {{- '<|im_end|>\ + ' }} + {%- endif %} + {%- endif %} + """ + + def merge_tool_group(group, tokenizer): + if len(group) == 1: + return group[0] + + msg0 = group[0] + merged_content = "".join(f"\n{msg.content}\n\n" for msg in group) + merged_content = merged_content[len("\n") :] + merged_content = merged_content[: -len("\n")] + merged = ExtendedMessage( + author=msg0.author, + role=msg0.role, + content=merged_content, + tokenizer=tokenizer, + token_generator="manual", + build_from_uuid=msg0.uuid, + tools=msg0.tools, + tool_calls=msg0.tool_calls, + token_logprob_arr=msg0.token_logprob_arr, + first_message=msg0.first_message, + ) + # re-compute token_arr + auto_tokenize_targets = [{"role": msg.role, "content": msg.content_for_future} for msg in group] + merged.token_arr, _ = merged.get_inc_simple( + text_frag_from=ajet_apply_chat_template( + tokenizer=tokenizer, + conversation=DUMMY_MSG, + tokenize=False, + tools=merged.tools, + add_generation_prompt=False, + ), + text_frag_to=ajet_apply_chat_template( + tokenizer, + conversation=DUMMY_MSG + auto_tokenize_targets, + tokenize=False, + tools=merged.tools, + add_generation_prompt=False, + ), + tokenizer=tokenizer, + ) + return merged + + groups = [] + current_tool_group = [] + for msg in ext_msg_array: + if msg.role == "tool": + current_tool_group.append(msg) + else: + if current_tool_group: + groups.append(current_tool_group) + current_tool_group = [] + groups.append([msg]) + if current_tool_group: + groups.append(current_tool_group) + + result_ext_msg_array = [merge_tool_group(group, tokenizer) for group in groups] + return result_ext_msg_array diff --git a/ajet/schema/logprob.py b/ajet/schema/logprob.py new file mode 100644 index 00000000..dc736fb8 --- /dev/null +++ b/ajet/schema/logprob.py @@ -0,0 +1,14 @@ +# from typing import Any, Dict, List +# from loguru import logger +# from omegaconf import DictConfig +# from openai.types.chat.chat_completion import ChatCompletion +# from verl import DataProto + + +from pydantic import BaseModel + + +class TokenAndProb(BaseModel): + token_id: int + logprob: float + decoded_string: str diff --git a/ajet/schema/task.py b/ajet/schema/task.py new file mode 100644 index 00000000..6d94796c --- /dev/null +++ b/ajet/schema/task.py @@ -0,0 +1,46 @@ +from typing import Any, Dict, List, Union + +from pydantic import BaseModel, Field + +""" +The basic schema for task_reader module +""" + + +class Task(BaseModel): + main_query: str = Field(default="") + init_messages: List[dict] = Field(default=[]) + task_id: str = Field(default="") + env_type: str = Field(default="") + metadata: dict = Field(default_factory=dict) + + +""" +For workflow execution, include task uuid and gym client if needed +""" + + +class WorkflowTask(BaseModel): + env_type: str = Field(default="") + task_id: str = Field(default="") + task_thread_index: int = Field(default=0) + task_batch_index: int = Field(default=0) + task_tag: str = Field(default="") + episode_uuid: str = Field(default="") + observation_window: dict = Field(default={}) + llm_inference_fn: Any = Field(default=None) + tokenizer: Any = Field(default=None) + task: Task = Field(default=Task()) + gym_env: Any = Field(default=None) # agentscope runtime handle or env service handle + + +""" +workflow output, user should provide as workflow output +""" + + +class WorkflowOutput(BaseModel): + reward: Union[float, List[float], None] = Field(default=None) + is_success: Union[bool, None] = Field(default=None) + metadata: Dict[str, Any] = Field(default_factory=dict) + log_metrics: Dict[str, Union[float, List[float]]] = Field(default_factory=dict) diff --git a/astune/schema/trajectory.py b/ajet/schema/trajectory.py similarity index 51% rename from astune/schema/trajectory.py rename to ajet/schema/trajectory.py index 7e93091d..b8fddd11 100644 --- a/astune/schema/trajectory.py +++ b/ajet/schema/trajectory.py @@ -1,43 +1,47 @@ -import numpy as np from typing import Any, Dict, List, Union + +import numpy as np from pydantic import BaseModel, Field +from loguru import logger class Reward(BaseModel): + # raw reward: the original reward from environment raw_reward: float = Field(default=0.0) + # raw step reward: the original step-wise rewards from environment raw_step_reward: Union[List[float], None] = Field(default=[]) - step_reward: List[float] = Field(default=[]) + # step reward: reward after post-processing, e.g., repeatition penalty + step_reward_arr: List[float] = Field(default=[]) + # advantage values: mean reward is group-wise averaged. e.g. ( (r11+r12+r13)/3 , (r21+r22)/2 ) / 2 step_advantage: List[float] = Field(default=[]) + # simple advantage values: mean reward is sample-wise averaged. e.g. (r11, r12, r13, r21, r22) / 5 step_advantage_simple: List[float] = Field(default=[]) + # the success or not, either 0 or 1. average multiple samples to get success rate success_rate: float = Field(default=0.0) + # llm produce abnormal or illegal output, such as ever repeating the same sentence madness: float = Field(default=0.0) + # description of the reward description: str = Field(default="Outcome 1 denotes success, and 0 denotes failure.") + # metadata for reward metadata: dict = Field(default_factory=dict) @property def performance_reward(self): - if (self.step_reward is not None) and len(self.step_reward) > 0: - res = np.mean(self.step_reward) - # print(f"Performance reward computed as mean of step_reward: {res}") + # performance reward is only used in dynamic rollout + # used to terminate hopeless rollout thread early + # this reward is NOT used in training + if (self.step_reward_arr is not None) and len(self.step_reward_arr) > 0: + res = np.mean(self.step_reward_arr) return res else: return self.raw_reward -class Trajectory(BaseModel): - task_batch_index: int = Field(default=0) - task_tag: str = Field(default="") - - steps: List[dict] = Field(default_factory=list) - query: str = Field(default="") - - is_terminated: bool = Field(default=False) - reward: Reward = Field(default_factory=Reward) - - metadata: dict = Field(default_factory=dict) - @property - def success(self) -> bool: - return self.reward_outcome > 0 + def final_scalar_reward(self): + # to compute scalar reward, we average step_reward_arr + reward = self.step_reward_arr + reward = float(np.mean(reward)) + return reward class Sample(BaseModel): @@ -69,33 +73,33 @@ class Sample(BaseModel): step_reward: float = 0.0 reference_advantage: float = 0.0 - def __init__(self, cmt_tokenized:dict, messages, config, **kwargs): + def __init__(self, tracker_tokenized: dict, messages, config, **kwargs): super().__init__(**kwargs) - self.max_prompt_len = config.astune.data.max_prompt_length - self.max_response_len = config.astune.data.max_response_length - self.max_model_len = config.astune.data.max_response_length + config.astune.data.max_prompt_length + self.max_prompt_len = config.ajet.data.max_prompt_length + self.max_response_len = config.ajet.data.max_response_length + self.max_model_len = config.ajet.data.max_response_length + config.ajet.data.max_prompt_length - self.input_ids = cmt_tokenized["input_ids"] - self.attention_mask = cmt_tokenized["attention_mask"] - self.loss_mask = cmt_tokenized["loss_mask"] - self.position_ids = cmt_tokenized["position_ids"] - self.logprobs = cmt_tokenized["logprobs"] + self.input_ids = tracker_tokenized["input_ids"] + self.attention_mask = tracker_tokenized["attention_mask"] + self.loss_mask = tracker_tokenized["loss_mask"] + self.position_ids = tracker_tokenized["position_ids"] + self.logprobs = tracker_tokenized["logprobs"] - self.prompt_ids = cmt_tokenized["prompt_ids"] - self.prompt_attention_mask = cmt_tokenized["prompt_attention_mask"] - self.prompt_loss_mask = cmt_tokenized["prompt_loss_mask"] - self.prompt_position_ids = cmt_tokenized["prompt_position_ids"] - self.prompt_logprobs = cmt_tokenized["prompt_logprobs"] + self.prompt_ids = tracker_tokenized["prompt_ids"] + self.prompt_attention_mask = tracker_tokenized["prompt_attention_mask"] + self.prompt_loss_mask = tracker_tokenized["prompt_loss_mask"] + self.prompt_position_ids = tracker_tokenized["prompt_position_ids"] + self.prompt_logprobs = tracker_tokenized["prompt_logprobs"] - self.response_ids = cmt_tokenized["response_ids"] - self.response_attention_mask = cmt_tokenized["response_attention_mask"] - self.response_loss_mask = cmt_tokenized["response_loss_mask"] - self.response_position_ids = cmt_tokenized["response_position_ids"] - self.response_logprobs = cmt_tokenized["response_logprobs"] + self.response_ids = tracker_tokenized["response_ids"] + self.response_attention_mask = tracker_tokenized["response_attention_mask"] + self.response_loss_mask = tracker_tokenized["response_loss_mask"] + self.response_position_ids = tracker_tokenized["response_position_ids"] + self.response_logprobs = tracker_tokenized["response_logprobs"] - self.reference_advantage = cmt_tokenized["reference_advantage"] - self.step_reward = cmt_tokenized["step_reward"] + self.reference_advantage = tracker_tokenized["reference_advantage"] + self.step_reward = tracker_tokenized["step_reward"] self.messages = messages @@ -104,7 +108,6 @@ def __init__(self, cmt_tokenized:dict, messages, config, **kwargs): assert len(self.response_ids) != 0, "response_ids should not be empty" def truncate_output_ids(self) -> None: - assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask) assert len(self.prompt_ids) == len(self.prompt_attention_mask) == len(self.prompt_position_ids) == len(self.prompt_loss_mask) == len(self.prompt_logprobs) assert len(self.response_ids) == len(self.response_attention_mask) == len(self.response_position_ids) == len(self.response_loss_mask) == len(self.response_logprobs) @@ -114,22 +117,13 @@ def truncate_output_ids(self) -> None: if len(self.prompt_ids) > self.max_prompt_len: truncate_any = True - print(f"-------------------------------------------------------------------------------------------------------") - print(f"Warning: prompt_ids length {len(self.prompt_ids)} exceeds max_prompt_len {self.max_prompt_len}, truncating.") - print(f"-------------------------------------------------------------------------------------------------------") - raise RuntimeError("Prompt length exceeds maximum allowed length. Please adjust the input data.") - self.prompt_ids = self.prompt_ids[-self.max_prompt_len:] - self.prompt_attention_mask = self.prompt_attention_mask[-self.max_prompt_len:] - self.prompt_position_ids = self.prompt_position_ids[-self.max_prompt_len:] - self.prompt_loss_mask = self.prompt_loss_mask[-self.max_prompt_len:] - self.prompt_logprobs = self.prompt_logprobs[-self.max_prompt_len:] - + raise RuntimeError(f"Warning: prompt_ids length {len(self.prompt_ids)} exceeds max_prompt_len {self.max_prompt_len}, truncating.") if len(self.response_ids) > self.max_response_len: truncate_any = True - print(f"-------------------------------------------------------------------------------------------------------") - print(f"Warning: response_ids length {len(self.response_ids)} exceeds max_response_len {self.max_response_len}, truncating.") - print(f"-------------------------------------------------------------------------------------------------------") + logger.warning("-------------------------------------------------------------------------------------------------------") + logger.warning(f"Warning: response_ids length {len(self.response_ids)} exceeds max_response_len {self.max_response_len}, truncating.") + logger.warning("-------------------------------------------------------------------------------------------------------") self.response_ids = self.response_ids[: self.max_response_len] self.response_attention_mask = self.response_attention_mask[: self.max_response_len] self.response_position_ids = self.response_position_ids[: self.max_response_len] @@ -147,4 +141,4 @@ def discard(self) -> None: """ Discard the experience. """ - raise RuntimeError('Never use this method.') + raise RuntimeError("Never use this method.") diff --git a/ajet/task_judge/base_judge.py b/ajet/task_judge/base_judge.py new file mode 100644 index 00000000..01fc8fae --- /dev/null +++ b/ajet/task_judge/base_judge.py @@ -0,0 +1,9 @@ +from ajet.workflow import WorkflowOutput, WorkflowTask + + +class BaseJudge: + def __init__(self, config): + self.config = config + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple: + raise NotImplementedError diff --git a/ajet/task_judge/env_service_as_judge.py b/ajet/task_judge/env_service_as_judge.py new file mode 100644 index 00000000..503ace82 --- /dev/null +++ b/ajet/task_judge/env_service_as_judge.py @@ -0,0 +1,24 @@ +from ajet.task_judge.base_judge import BaseJudge +from ajet.workflow import WorkflowOutput, WorkflowTask + + +class EnvServiceJudge(BaseJudge): + def __init__(self, config): + self.config = config + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple: + raw_reward = 0 + + env = workflow_task.gym_env + raw_reward = env.evaluate(workflow_task.episode_uuid, params={"sparse": False}) + if raw_reward >= 1: + is_success = True + else: + is_success = False + + if is_success: + raw_reward = 1.0 + raw_reward * 0.5 + else: + raw_reward = 0.0 + raw_reward * 0.5 + + return raw_reward, is_success diff --git a/ajet/task_judge/rm_auto_grader_judge.py b/ajet/task_judge/rm_auto_grader_judge.py new file mode 100644 index 00000000..d3f2074d --- /dev/null +++ b/ajet/task_judge/rm_auto_grader_judge.py @@ -0,0 +1,387 @@ +""" +RM Gallery Iterative Rubric Judge Integration + +This module integrates RM Gallery's IterativeRubricsGenerator capabilities into ajet's judge system. +It provides a data-driven approach to evaluate workflow outputs using automatically +generated rubrics from training samples. + +Key Features: +- Automatic rubric generation from training/validation samples using iterative Propose-Evaluate-Revise loop +- Support for both pointwise and listwise evaluation modes +- MCR²-based smart sampling for large datasets +- Optional LLM-based categorization to organize rubrics +- Flexible scoring based on LLM-generated rubrics +- Seamless integration with ajet's workflow system +""" + +import asyncio +import json +import os +from typing import List, Optional + +from beast_logger import print_dict +from loguru import logger +from rm_gallery.core.generator.iterative_rubric.generator import ( + IterativeListwiseRubricsGeneratorConfig, + IterativePointwiseRubricsGeneratorConfig, + IterativeRubricsGenerator, +) +from rm_gallery.core.graders.llm_grader import LLMGrader +from rm_gallery.core.graders.schema import GraderMode +from rm_gallery.core.models.dashscope_chat_model import DashScopeChatModel +from rm_gallery.core.models.schema.prompt_template import LanguageEnum + +from ajet.schema.task import Task, WorkflowOutput +from ajet.task_judge.base_judge import BaseJudge + + +class AutoGraderJudge(BaseJudge): + """ + A data-driven judge that uses RM Gallery's IterativeRubricsGenerator to evaluate workflow outputs. + + This judge automatically generates evaluation rubrics from a set of reference samples + and then uses those rubrics to score new workflow outputs. It uses an iterative + Propose-Evaluate-Revise loop to ensure high-quality rubrics. + + Workflow: + 1. Initialize with configuration and reference samples + 2. Generate rubrics from reference samples using iterative refinement (one-time setup) + 3. Evaluate each workflow output against the generated rubrics + + Example Config (in YAML): + task_judge: + # RM Gallery Model Configuration + model_name: "qwen-plus" # or "gpt-4", "claude-3-sonnet", etc. + api_key: "your-api-key" + base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" # optional + + # Rubric Generation Configuration + grader_mode: "pointwise" # or "listwise" + language: "en" # or "zh" + + # Advanced Configuration (optional, uses sensible defaults) + query_specific_generate_number: 1 # number of rubrics per sample (default: 1) + enable_categorization: false # use LLM-based categorization (default: false) + categories_number: 5 # number of categories when categorization enabled (default: 5) + + # Reference samples for rubric generation + input_data_type: "dataset_file" # or other supported types + dataset_file: + training: + file_path: "tutorial/example_rm_auto_grader/rubrics_train.jsonl" + + # Custom field mapping (optional, uses defaults if not specified) + query_field: "main_query" # field in task containing query + answer_field: "final_answer" # field in metadata containing answer + reference_field: "answer" # field in task.metadata containing reference + + # Pointwise mode settings (only for pointwise mode) + min_score: 0 # minimum score + max_score: 10 # maximum score + """ + + def __init__(self, config): + """Initialize the AutoGraderJudge. + + Args: + config: Configuration object containing model and rubric generation settings + """ + super().__init__(config) + + self.config = config + + # Initialize the model FIRST + # Get API key from config or environment + import os + + api_key = getattr(config.ajet.task_judge.rubrics_auto_grader, "api_key", None) or os.getenv("DASHSCOPE_API_KEY") + + self.model = DashScopeChatModel( + model=config.ajet.task_judge.rubrics_auto_grader.model_name, + api_key=api_key, + stream=False, + enable_thinking=False, + ) + + # Parse config (needs self.model to be initialized) + self.generator_config = self._parse_config() + + # Storage for generated grader + self.llm_grader: Optional[LLMGrader] = None + self.rubrics_generated = False + + # Field mappings for data extraction + self.query_field = getattr(config.ajet.task_judge.rubrics_auto_grader, "query_field", "main_query") + self.answer_field = getattr(config.ajet.task_judge.rubrics_auto_grader, "answer_field", "final_answer") + self.reference_field = getattr(config.ajet.task_judge.rubrics_auto_grader, "reference_field", "answer") + + logger.info(f"AutoGraderJudge initialized with mode={self.generator_config.grader_mode.value}, " f"language={self.generator_config.language.value}") + + def _parse_config( + self, + ) -> IterativePointwiseRubricsGeneratorConfig | IterativeListwiseRubricsGeneratorConfig: + """Parse ajet config into IterativeRubricsGeneratorConfig.""" + judge_config = self.config.ajet.task_judge.rubrics_auto_grader + + # Parse grader mode + grader_mode_str = getattr(judge_config, "grader_mode", "pointwise").lower() + grader_mode = GraderMode.POINTWISE if grader_mode_str == "pointwise" else GraderMode.LISTWISE + + # Parse language + language_str = getattr(judge_config, "language", "en").upper() + language = LanguageEnum.ZH if language_str == "ZH" else LanguageEnum.EN + + # Common configuration parameters + common_config = { + "model": self.model, + "grader_name": getattr(judge_config, "grader_name", "RM Iterative Rubric Grader"), + "language": language, + "enable_categorization": getattr(judge_config, "enable_categorization", False), + "query_specific_generate_number": getattr(judge_config, "query_specific_generate_number", 1), + "categories_number": getattr(judge_config, "categories_number", 5), + "max_retries": getattr(judge_config, "max_retries", 5), + "max_epochs": getattr(judge_config, "max_epochs", 3), + "batch_size": getattr(judge_config, "batch_size", 10), + "mcr_batch_size": getattr(judge_config, "mcr_batch_size", 10), + "min_increment_threshold": getattr(judge_config, "min_increment_threshold", 0.002), + "patience": getattr(judge_config, "patience", 2), + "max_iterations": getattr(judge_config, "max_iterations", 50), + "max_total_rubrics": getattr(judge_config, "max_total_rubrics", 200), + "custom_evaluation_prompt": getattr(judge_config, "custom_evaluation_prompt", None), + } + + # Create mode-specific config + if grader_mode == GraderMode.POINTWISE: + return IterativePointwiseRubricsGeneratorConfig( + **common_config, + min_score=getattr(judge_config, "min_score", 0), + max_score=getattr(judge_config, "max_score", 10), + ) + else: + return IterativeListwiseRubricsGeneratorConfig(**common_config) + + async def read_reference_samples_from_dataset(self) -> List[Task]: + # read dataset from config + from ajet.task_reader import RouterTaskReader + + reader = RouterTaskReader( + reader_type=self.config.ajet.task_judge.rubrics_auto_grader.input_data_type, + reader_config=self.config.ajet.task_judge.rubrics_auto_grader, + ) + return reader.task_reader.get_training_tasks() + + async def generate_rubrics_from_samples(self, reference_samples: List[Task] = []) -> None: + """ + Generate evaluation rubrics from reference samples using iterative refinement. + + This method should be called once during initialization with a set of + reference tasks that represent the types of problems to be evaluated. + + Args: + reference_samples: List of Task objects with reference data + """ + + if len(reference_samples) == 0: + reference_samples = await self.read_reference_samples_from_dataset() + + logger.info(f"Generating rubrics from {len(reference_samples)} reference samples...") + + # Convert Task samples to the format expected by IterativeRubricsGenerator + training_dataset = [] + for sample in reference_samples: + data_item = self._task_to_training_data(sample) + if data_item: + training_dataset.append(data_item) + + if not training_dataset: + raise ValueError("No valid training data could be created from reference samples") + + logger.info(f"Created {len(training_dataset)} training samples for rubric generation") + + # Create IterativeRubricsGenerator + generator = IterativeRubricsGenerator(config=self.generator_config) + + # Generate rubrics and get LLMGrader + self.llm_grader = await generator.generate(dataset=training_dataset) + + # Save the grader + experiment_dir = self.config.ajet.experiment_dir + grader_save_dir = os.path.join(experiment_dir, "auto_grader.json") + # make dirs if not exist + os.makedirs(experiment_dir, exist_ok=True) + print_dict({"message": "Saving generated grader config to", "path": grader_save_dir}) + json.dump( + self.llm_grader.to_dict(), + open(grader_save_dir, "w", encoding="utf-8"), + indent=4, + ensure_ascii=False, + ) + + self.rubrics_generated = True + + logger.info("Rubrics generated successfully!") + logger.info(f"Generated rubrics:\n{self.llm_grader.rubrics}") + + async def load_rubrics_from_cache(self) -> None: + """ + Load a pre-generated grader configuration from file. + + Args: + grader_config_path: Path to the JSON file containing the grader config + """ + + # Load grader config and inject model + try: + experiment_dir = self.config.ajet.experiment_dir + grader_save_dir = os.path.join(experiment_dir, "auto_grader.json") + grader_config = json.load(open(grader_save_dir, "r", encoding="utf-8")) + grader_config["model"] = self.model + self.llm_grader = LLMGrader.from_config(grader_config) + except Exception: + logger.exception("Failed to load grader config from") + await self.generate_rubrics_from_samples([]) + + def _task_to_training_data(self, task: Task) -> Optional[dict]: + """ + Convert Task to training data format for IterativeRubricsGenerator. + + Args: + task: The workflow task containing query and reference with labels + + Returns: + Training data dict or None if conversion fails + + Expected formats: + Pointwise: {"query": str, "response": str, "label_score": int} + Listwise: {"query": str, "responses": List[str], "label_rank": List[int]} + """ + try: + # Extract query + query = getattr(task, self.query_field, "") + if not query and hasattr(task, "metadata"): + query = task.metadata.get(self.query_field, "") + + if not query: + raise ValueError(f"Query field '{self.query_field}' not found in task") + + metadata = task.metadata if hasattr(task, "metadata") else {} + + if self.generator_config.grader_mode == GraderMode.POINTWISE: + # Pointwise: expect metadata with "answer" and "score" + if "answer" in metadata and "score" in metadata: + return { + "query": query, + "response": metadata["answer"], + "label_score": metadata["score"], + } + else: + raise ValueError(f"Metadata must contain 'answer' and 'score' for pointwise training data in task {task.task_id}") + + else: # LISTWISE + # Listwise: expect metadata with "candidates" containing list of {answer, rank} + if "candidates" in metadata and isinstance(metadata["candidates"], list): + responses = [] + label_ranks = [] + for candidate in metadata["candidates"]: + responses.append(candidate["answer"]) + label_ranks.append(candidate["rank"]) + + return { + "query": query, + "responses": responses, + "label_rank": label_ranks, + } + else: + raise ValueError(f"Metadata must contain 'candidates' list for listwise training data in task {task.task_id}") + + except Exception as e: + logger.warning(f"Failed to convert task to training data: {e}") + return None + + async def _async_compute_reward(self, task: Task, workflow_output: WorkflowOutput | List[WorkflowOutput]): + """ + Asynchronously compute reward using the generated rubrics. + + Args: + task: The task being evaluated + workflow_output: Single output for pointwise, or list of outputs for listwise + + Returns: + For pointwise: tuple (raw_reward, is_success) + For listwise: List of ranking results + """ + if not self.rubrics_generated or self.llm_grader is None: + raise RuntimeError("Rubrics have not been generated yet. " "Call generate_rubrics_from_samples() first.") + + # Extract query + query = getattr(task, self.query_field, "") + if not query and hasattr(task, "metadata"): + query = task.metadata.get(self.query_field, "") + + # Evaluate using LLMGrader + try: + if self.generator_config.grader_mode == GraderMode.POINTWISE: + # Pointwise evaluation: single output + if isinstance(workflow_output, list): + # If list provided, evaluate first output + answer = workflow_output[0].metadata.get(self.answer_field, "") + else: + answer = workflow_output.metadata.get(self.answer_field, "") + + result = await self.llm_grader.aevaluate(query=query, answer=answer) + return result + + else: # LISTWISE + # Listwise evaluation: multiple outputs + if not isinstance(workflow_output, list): + logger.error("Listwise mode requires a list of workflow outputs") + return None + + # Format responses for listwise evaluation + responses = [] + for output in workflow_output: + responses.append(output.metadata.get(self.answer_field, "")) + + # Format answer as required by listwise grader + answer = "\n\n".join([f"Response {i+1}:\n{resp}" for i, resp in enumerate(responses)]) + + result = await self.llm_grader.aevaluate(query=query, answer=answer, num_responses=len(responses)) + return result + + except Exception as e: + logger.error(f"Error during evaluation: {e}") + return None + + def compute_reward(self, task: Task, workflow_output: WorkflowOutput) -> tuple: + """ + Compute reward for a workflow output (synchronous wrapper). + + This is the main interface called by ajet's workflow system. + + Args: + task: The task being evaluated + workflow_output: The output to evaluate + + Returns: + tuple: (raw_reward, is_success) + """ + # Check if we're already in an async context + try: + loop = asyncio.get_running_loop() + # If we get here, we're in an async context + # We need to use nest_asyncio or raise an error + try: + import nest_asyncio + + nest_asyncio.apply() + return loop.run_until_complete(self._async_compute_reward(task, workflow_output)) + except ImportError: + raise RuntimeError("compute_reward() was called from an async context. " "Please use 'await judge._async_compute_reward(task, output)' instead, " "or install nest_asyncio: pip install nest_asyncio") + except RuntimeError: + # No event loop running, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(self._async_compute_reward(task, workflow_output)) + finally: + loop.close() diff --git a/ajet/task_reader/__init__.py b/ajet/task_reader/__init__.py new file mode 100644 index 00000000..2d7d7322 --- /dev/null +++ b/ajet/task_reader/__init__.py @@ -0,0 +1,126 @@ +from typing import List + +import datasets +import numpy as np + +from ajet.schema.task import Task +from ajet.task_reader.data_generator_reader import DataGeneratorTaskReader +from ajet.task_reader.env_service_reader import EnvServiceTaskReader +from ajet.task_reader.hf_dataset_reader import HuggingFaceTaskReader +from ajet.task_reader.jsonl_reader import JsonlTaskReader +from ajet.task_reader.task_reader_base import BaseTaskReader +from ajet.task_reader.tracing_reader import TracingReader + + +class RandomDummyTaskReader(BaseTaskReader): + def __init__(self, reader_config): + super().__init__(reader_config) + + def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: + tasks = [] + # Save the current random state + original_state = np.random.get_state() + np.random.seed(42) + random_number = [x for x in range(1000)] + # shuffle + np.random.shuffle(random_number) + for idx in random_number: + task = Task( + main_query=f"[dummy task @ {idx}]", + init_messages=[], + task_id=str(idx), + env_type="no_env", + metadata={"random_number": idx}, + ) + tasks.append(task) + # Restore the original random state + np.random.set_state(original_state) + return tasks + + def get_training_tasks(self) -> List[Task]: + return self._load_dataset_split("dataset_name", "split") + + def get_validation_tasks(self) -> List[Task]: + return self._load_dataset_split("dataset_name", "split") + + +class RouterTaskReader(BaseTaskReader): + def __init__(self, reader_type, reader_config): + super().__init__(None) + + task_reader_type = reader_type + if task_reader_type == "env_service": + self.task_reader = EnvServiceTaskReader(reader_config) + elif task_reader_type == "jsonl_dataset_file": + self.task_reader = JsonlTaskReader(reader_config) + elif task_reader_type == "huggingface_dat_repo": + self.task_reader = HuggingFaceTaskReader(reader_config) + elif task_reader_type == "tracing": + self.task_reader = TracingReader(reader_config) + elif task_reader_type == "data_generation": + self.task_reader = DataGeneratorTaskReader(reader_config) + elif task_reader_type == "random_dummy": + self.task_reader = RandomDummyTaskReader(reader_config) + else: + raise ValueError(f"Unsupported task reader type: {task_reader_type}") + + def get_training_tasks(self) -> List[Task]: + result = self.task_reader.get_training_tasks() + np.random.shuffle(result) # type: ignore + return result + + def get_validation_tasks(self) -> List[Task]: + result = self.task_reader.get_validation_tasks() + np.random.shuffle(result) # type: ignore + return result + + +def task_to_standard_dataset(tasks: List[Task]) -> datasets.Dataset: + """ + Convert a list of Task objects to a standard Hugging Face Dataset. + + Args: + tasks (List[Task]): List of Task objects. + + Returns: + datasets.Dataset: Hugging Face Dataset containing the tasks. + """ + data = { + "task_id": [], + "main_query": [], + "init_messages": [], + "env_type": [], + "metadata": [], + } + + for task in tasks: + data["task_id"].append(task.task_id) + data["main_query"].append(task.main_query) + data["init_messages"].append(task.init_messages) + data["env_type"].append(task.env_type) + data["metadata"].append(task.metadata) + + return datasets.Dataset.from_dict(data) + + +def dict_to_ajet_task(task_dict: dict) -> Task: + """ + Convert a dictionary to a Task object. + + Args: + task_dict (dict): Dictionary containing task fields. + + Returns: + Task: Task object created from the dictionary. + """ + for vip_key in ["main_query", "task_id", "env_type", "metadata", "init_messages"]: + if vip_key not in task_dict: + raise ValueError(f"Key {vip_key} not found in task.raw_task") + + return Task( + main_query=task_dict.get("main_query", ""), + init_messages=task_dict.get("init_messages", []), + task_id=task_dict.get("task_id", ""), + env_type=task_dict.get("env_type", ""), + metadata=task_dict.get("metadata", {}), + ) diff --git a/ajet/task_reader/data_generator_reader.py b/ajet/task_reader/data_generator_reader.py new file mode 100644 index 00000000..119da9b2 --- /dev/null +++ b/ajet/task_reader/data_generator_reader.py @@ -0,0 +1,380 @@ +import hashlib +import json +import math +import os +import random +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List + +import dotenv +from loguru import logger +from tqdm import tqdm + +from ajet.data_generator.knowledge_augmentation import KnowledgeAugmentor +from ajet.data_generator.task_augmentation import TaskAugmentor +from ajet.schema.task import Task +from ajet.task_reader.document_reader.doc_reader import DocReader +from ajet.task_reader.task_reader_base import BaseTaskReader +from ajet.task_reader.tracing_reader.filters.deduplication_filter import ( + DeduplicationFilter, +) + +dotenv.load_dotenv() + + +class DataGeneratorTaskReader(BaseTaskReader): + """ + Enhanced version of TaskReaderDataGenerator with multi-threading support, + progress bars, and improved batch calculation. + """ + + def __init__(self, reader_config): + super().__init__(reader_config) + self.reader_config = reader_config + project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + dataset_dir = os.path.join(project_root, "dataset/jsonl") + os.makedirs(dataset_dir, exist_ok=True) + + # Build a cache key based on generation-related config to avoid rigid filenames + document_path = getattr(reader_config.data_generation.document_reader, "document_path", None) + # Convert document_path to a hashable string representation + if isinstance(document_path, (list, tuple)): + document_path_str = ",".join(sorted(str(p) for p in document_path)) + elif document_path is not None: + document_path_str = str(document_path) + else: + document_path_str = "" + + cache_config = { + "task_num": reader_config.data_generation.task_num, + "num_workers": getattr(reader_config.data_generation, "num_workers", 32), + "query_reader_type": getattr(reader_config.data_generation.query_reader, "type", None), + "document_reader": document_path_str, + "deduplication_filter": { + "similarity_threshold": getattr( + reader_config.data_generation.deduplication_filter.params, + "similarity_threshold", + None, + ), + "db_path": getattr(reader_config.data_generation.deduplication_filter.params, "db_path", None), + "model": getattr(reader_config.data_generation.deduplication_filter.params, "model", None), + }, + } + cache_key_str = json.dumps(cache_config, sort_keys=True, ensure_ascii=False) + cache_key = hashlib.md5(cache_key_str.encode("utf-8")).hexdigest()[:8] + + self.generated_train_file = os.path.join(dataset_dir, f"generated_train_tasks_{cache_key}.jsonl") + self.generated_valid_file = os.path.join(dataset_dir, f"generated_valid_tasks_{cache_key}.jsonl") + + # Get number of workers from config, default to 32 + self.num_workers = getattr(reader_config.data_generation, "num_workers", 32) + + # Thread-safe lock for shared resources + self.lock = threading.Lock() + + # Initialize duplicate filter + if self.reader_config.data_generation.deduplication_filter.enabled: + self.duplicate_filter = DeduplicationFilter( + similarity_threshold=self.reader_config.data_generation.deduplication_filter.params.similarity_threshold, + db_path=self.reader_config.data_generation.deduplication_filter.params.db_path, + model=self.reader_config.data_generation.deduplication_filter.params.model, + api_key=self.reader_config.data_generation.deduplication_filter.params.api_key, + base_url=self.reader_config.data_generation.deduplication_filter.params.base_url, + ) + else: + self.duplicate_filter = None + # Initialize task reader + from ajet.task_reader import RouterTaskReader + + task_reader = RouterTaskReader( + reader_type=self.reader_config.data_generation.query_reader.type, + reader_config=self.reader_config.data_generation.query_reader, + ) + self.original_tasks = task_reader.get_training_tasks() + + # Check cache files and load/generate accordingly + train_cache_exists = os.path.exists(self.generated_train_file) + valid_cache_exists = os.path.exists(self.generated_valid_file) + + # Load validation tasks from cache if available + if valid_cache_exists: + try: + logger.info(f"Validation cache found: {self.generated_valid_file}") + self.doc_tasks = self._read_jsonl_file(self.generated_valid_file) + logger.info(f"Loaded {len(self.doc_tasks)} validation tasks from cache") + except Exception as e: + logger.error(f"Error loading validation cache: {e}") + self.doc_tasks = None + else: + self.doc_tasks = None + + # Load training tasks from cache if available + if train_cache_exists: + try: + logger.info(f"Training cache found: {self.generated_train_file}") + self.new_tasks = self._read_jsonl_file(self.generated_train_file) + logger.info(f"Loaded {len(self.new_tasks)} training tasks from cache") + except Exception as e: + logger.error(f"Error loading training cache: {e}") + self.new_tasks = None + else: + self.new_tasks = None + + # Generate missing tasks + self._generate_and_save_tasks() + + def _generate_document_tasks_worker(self, args): + """ + Worker function for generating document-based tasks. + + Args: + args: Tuple containing (batch_index, document, knowledge_augmentor) + + Returns: + Tuple of (batch_index, generated_tasks, error_message) + """ + batch_index, document, knowledge_augmentor = args + try: + tasks = knowledge_augmentor.generate_task(source_task=None, document=document) + return batch_index, tasks, None + except Exception as e: + error_msg = f"Error generating document batch {batch_index}: {e}" + return batch_index, [], error_msg + + def _generate_augmented_tasks_worker(self, args): + """ + Worker function for generating augmented tasks. + + Args: + args: Tuple containing (task_index, source_task, document, task_augmentor) + + Returns: + Tuple of (task_index, generated_task, error_message) + """ + task_index, source_task, document, task_augmentor = args + try: + new_task = task_augmentor.generate_task(source_task=source_task, document=document) + return task_index, new_task, None + except Exception as e: + error_msg = f"Error generating task {task_index}: {e}" + return task_index, None, error_msg + + def _generate_and_save_tasks(self): + """ + Enhanced version with selective generation based on cache availability. + """ + logger.info(f"Using {self.num_workers} workers for task generation") + + document_reader = DocReader(self.reader_config) + documents = document_reader.get_document() + task_num = self.reader_config.data_generation.task_num + + # Phase 1: Generate document-based tasks only if not cached + if self.doc_tasks is None and documents is not None: + logger.info("Phase 1: Generating document-based tasks for validation...") + + task_augmentor = TaskAugmentor(self.reader_config) + knowledge_augmentor = KnowledgeAugmentor(self.reader_config) + + # Calculate batches using ceiling division + N = 10 # 10 is the hyperparameter we found that produces relatively stable outputs, same with knowledge_augmentation + doc_task_rounds = math.ceil(task_num / N) + logger.info(f"Generating {doc_task_rounds} document-based task batches (ceil({task_num}/10))") + + self.doc_tasks = [] + + # Prepare arguments for workers + doc_worker_args = [] + for i in range(doc_task_rounds): + document = documents[i % len(documents)] + doc_worker_args.append((i, document, knowledge_augmentor)) + + # Execute document task generation with progress bar + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + # Submit all tasks + future_to_batch = {executor.submit(self._generate_document_tasks_worker, args): args[0] for args in doc_worker_args} + + # Process results with progress bar + with tqdm(total=doc_task_rounds, desc="Document tasks", unit="batch") as pbar: + for future in as_completed(future_to_batch): + batch_index, tasks, error_msg = future.result() + + if error_msg: + logger.error(f"\n{error_msg}") + else: + with self.lock: + self.doc_tasks.extend(tasks) + + pbar.update(1) + if self.duplicate_filter is not None: + self.doc_tasks = self.duplicate_filter.filter_sync(self.doc_tasks) + logger.info(f"Generated {len(self.doc_tasks)} document-based tasks") + + # Save doc_tasks as validation tasks cache + if self.doc_tasks: + logger.info(f"Saving {len(self.doc_tasks)} validation tasks to cache: {self.generated_valid_file}") + self._save_tasks_to_jsonl(self.doc_tasks, self.generated_valid_file) + else: + logger.info("Phase 1: Skipping document task generation (using cached validation tasks)") + + # Phase 2: Generate augmented tasks only if not cached + if self.new_tasks is None: + logger.info("Phase 2: Generating augmented tasks using original + document tasks...") + + task_augmentor = TaskAugmentor(self.reader_config) + + self.new_tasks = [] + + # Combine original tasks and doc tasks for source task selection + if not self.original_tasks: + self.original_tasks = [] + if not self.doc_tasks: + self.doc_tasks = [] + combined_source_tasks = self.original_tasks + self.doc_tasks + logger.info(f"Using {len(combined_source_tasks)} source tasks ({len(self.original_tasks)} original + {len(self.doc_tasks)} document-based)") + + # Prepare arguments for workers + aug_worker_args = [] + for i in range(task_num): + source_task = random.choice(combined_source_tasks) + # Use a random document for augmentation + document = random.choice(documents) if documents else None + aug_worker_args.append((i, source_task, document, task_augmentor)) + + # Execute augmented task generation with progress bar + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + # Submit all tasks + future_to_task = {executor.submit(self._generate_augmented_tasks_worker, args): args[0] for args in aug_worker_args} + + # Process results with progress bar + with tqdm(total=task_num, desc="Augmented tasks", unit="task") as pbar: + for future in as_completed(future_to_task): + task_index, new_task, error_msg = future.result() + + if error_msg: + logger.error(f"\n{error_msg}") + elif new_task: + with self.lock: + self.new_tasks.append(new_task) + + pbar.update(1) + if self.duplicate_filter is not None: + self.new_tasks = self.duplicate_filter.filter_sync(self.new_tasks) + logger.info(f"Generated {len(self.new_tasks)} augmented tasks") + + # Save training tasks + if self.new_tasks: + logger.info(f"Saving {len(self.new_tasks)} training tasks to cache: {self.generated_train_file}") + self._save_tasks_to_jsonl(self.new_tasks, self.generated_train_file) + else: + logger.warning("No training tasks generated successfully") + else: + logger.info("Phase 2: Skipping training task generation (using cached training tasks)") + + logger.info(f"Task generation complete: {len(self.new_tasks)} training tasks, {len(self.doc_tasks)} validation tasks") + + def _read_jsonl_file(self, file_path): + """ + Read tasks from a JSONL file. + + Args: + file_path (str): Path to the JSONL file. + + Returns: + List[Task]: List of Task objects. + """ + tasks = [] + try: + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): # Skip empty lines + task_data = json.loads(line) + # Create a Task object from the JSON data + task = Task( + main_query=task_data.get("main_query", "[not defined]"), + init_messages=task_data.get("init_messages", []), + task_id=task_data.get("task_id", ""), + env_type=task_data.get("env_type", "no_env"), + metadata=task_data.get("metadata", task_data), + ) + tasks.append(task) + except FileNotFoundError: + raise ValueError(f"JSONL file not found: {file_path}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in file {file_path}: {str(e)}") + + if len(tasks) == 0: + raise ValueError(f"No tasks found in file: {file_path}") + + return tasks + + def _save_tasks_to_jsonl(self, tasks: List[Task], file_path: str): + """ + Save tasks to a JSONL file with progress bar. + + Args: + tasks (List[Task]): List of Task objects to save. + file_path (str): Path to the output JSONL file. + """ + try: + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + with open(file_path, "w", encoding="utf-8") as f: + with tqdm(total=len(tasks), desc="Saving tasks", unit="task") as pbar: + for task in tasks: + # Convert Task object to dictionary + task_data = { + "main_query": task.main_query, + "init_messages": task.init_messages, + "task_id": task.task_id, + "env_type": task.env_type, + "metadata": task.metadata, + } + # Write as JSON line + f.write(json.dumps(task_data, ensure_ascii=False) + "\n") + pbar.update(1) + + except Exception as e: + raise ValueError(f"Error saving tasks to {file_path}: {str(e)}") + + def get_training_tasks(self) -> List[Task]: + """ + Get training tasks from data generation. + + Returns: + List[Task]: List of training Task objects. + """ + return self.new_tasks + + def get_validation_tasks(self) -> List[Task]: + """ + Get validation tasks from data generation. + Now returns document-based tasks as validation tasks. + + Returns: + List[Task]: List of validation Task objects (doc_tasks). + """ + return getattr(self, "doc_tasks", []) + self.original_tasks + + def get_generation_stats(self) -> dict: + """ + Get statistics about the task generation process. + + Returns: + dict: Statistics including worker count, batch info, and task counts + """ + task_num = self.reader_config.data_generation.task_num + doc_task_rounds = math.ceil(task_num / 10) + + return { + "num_workers": self.num_workers, + "target_task_num": task_num, + "calculated_batches": doc_task_rounds, + "doc_tasks_generated": len(getattr(self, "doc_tasks", [])), + "augmented_tasks_generated": len(getattr(self, "new_tasks", [])), + "original_tasks_count": len(self.original_tasks), + "validation_tasks_count": len(getattr(self, "doc_tasks", [])), + "combined_source_tasks_count": len(self.original_tasks) + len(getattr(self, "doc_tasks", [])), + } diff --git a/ajet/task_reader/document_reader/doc_reader.py b/ajet/task_reader/document_reader/doc_reader.py new file mode 100644 index 00000000..bf25b5bb --- /dev/null +++ b/ajet/task_reader/document_reader/doc_reader.py @@ -0,0 +1,258 @@ +import hashlib +import json +import os +import re +import uuid +from pathlib import Path +from typing import List, Union + +from loguru import logger + +try: + from unstructured.partition.auto import partition +except Exception: + logger.warning("Cannot import dependency `unstructured`") + +from ajet.schema.document import Document +from ajet.task_reader.document_reader.document_reader_base import ( + DocReaderBase, +) + + +class DocReader(DocReaderBase): + """ + Enhanced document reader with file hash caching support and document chunking capabilities. + """ + + def __init__(self, config): + super().__init__(config) + self.cache_enabled = getattr(config.data_generation.document_reader, "cache_enabled", True) + self.chunk_size = getattr(config.data_generation.document_reader, "chunk_size", 10240) + self.split_by = getattr(config.data_generation.document_reader, "split_by", "sentence") + + def load_document(self, source: str, languages=["eng"]) -> str: + """ + Load text from a file with caching support. + """ + if not self.cache_enabled: + return self._parse_document(source, languages) + + # Calculate file hash + file_hash = self._calculate_file_hash(source) + if not file_hash: + return self._parse_document(source, languages) + + # Generate cache path (include chunking parameters in cache key) + cache_path = self._get_cache_path(source, file_hash, languages) + + # Try to load from cache + cached_content = self._load_from_cache(cache_path) + if cached_content: + logger.info(f"Cache hit: {Path(cache_path).name}") + return cached_content + + # Cache miss, parse document + logger.info(f"Cache miss: Parsing {Path(source).name}...") + text = self._parse_document(source, languages) + + # Save to cache + self._save_to_cache(cache_path, text) + logger.info(f"Cached to: {Path(cache_path).name}") + + return text + + def _parse_document(self, source: str, languages: List[str]) -> str: + """Parse document using unstructured.""" + text_pages = partition(source, languages=languages) + if not text_pages: + raise ValueError(f"No extractable text found in file: {source}") + return "\n\n".join([str(sub) for sub in text_pages]) + + def _calculate_file_hash(self, file_path: str) -> str: + """Calculate SHA256 hash of a file.""" + try: + hash_sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + return hash_sha256.hexdigest() + except Exception: + return "" + + def _get_cache_path(self, source_path: str, file_hash: str, languages: List[str]) -> str: + """Generate cache file path with chunking parameters.""" + source_path = Path(source_path) + lang_suffix = "_" + "_".join(sorted(languages)) if languages else "" + + # Include chunking parameters in cache filename + chunk_suffix = "" + if self.chunk_size: + chunk_suffix = f"_chunk{self.chunk_size}_{self.split_by}" + + cache_filename = f"{source_path.stem}.{file_hash[:16]}{lang_suffix}{chunk_suffix}.cache.json" + return str(source_path.parent / cache_filename) + + def _load_from_cache(self, cache_path: str) -> Union[str, None]: + """Load cached content.""" + try: + if os.path.exists(cache_path): + with open(cache_path, "r", encoding="utf-8") as f: + cache_data = json.load(f) + return cache_data.get("content", "") + except Exception: + pass + return None + + def _save_to_cache(self, cache_path: str, content: str) -> bool: + """Save content to cache.""" + try: + cache_data = {"content": content} + with open(cache_path, "w", encoding="utf-8") as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + return True + except Exception: + return False + + def _split_text_by_sentences(self, text: str, chunk_size: int) -> List[str]: + """ + Split text by sentences with specified chunk size. + """ + sentence_endings = r"[.!?。!?]+" + sentences = re.split(f"({sentence_endings})", text) + combined_sentences = [] + for i in range(0, len(sentences), 2): + sentence = sentences[i].strip() + if i + 1 < len(sentences): + sentence += sentences[i + 1] + if sentence: + combined_sentences.append(sentence) + + chunks = [] + current_chunk = [] + current_length = 0 + + for sentence in combined_sentences: + sentence_length = len(sentence) + + if current_length + sentence_length > chunk_size and current_chunk: + chunks.append(" ".join(current_chunk)) + current_chunk = [sentence] + current_length = sentence_length + else: + current_chunk.append(sentence) + current_length += sentence_length + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + return chunks + + def _split_text_by_paragraphs(self, text: str, chunk_size: int) -> List[str]: + """ + Split text by paragraphs with specified chunk size. + """ + paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] + + chunks = [] + current_chunk = [] + current_length = 0 + + for paragraph in paragraphs: + paragraph_length = len(paragraph) + + if current_length + paragraph_length > chunk_size and current_chunk: + chunks.append("\n\n".join(current_chunk)) + current_chunk = [paragraph] + current_length = paragraph_length + else: + current_chunk.append(paragraph) + current_length += paragraph_length + + if current_chunk: + chunks.append("\n\n".join(current_chunk)) + + return chunks + + def _split_text_by_characters(self, text: str, chunk_size: int) -> List[str]: + """ + Split text by characters with specified chunk size. + """ + chunks = [] + for i in range(0, len(text), chunk_size): + chunk = text[i : i + chunk_size] + if chunk.strip(): + chunks.append(chunk) + return chunks + + def _chunk_document(self, text: str) -> List[str]: + """ + Split document into chunks based on chunk_size and split_by parameters. + """ + if not self.chunk_size or self.chunk_size <= 0: + return [text] + + if self.split_by == "sentence": + return self._split_text_by_sentences(text, self.chunk_size) + elif self.split_by == "paragraph": + return self._split_text_by_paragraphs(text, self.chunk_size) + elif self.split_by == "character": + return self._split_text_by_characters(text, self.chunk_size) + else: + logger.warning(f"Unknown split_by value '{self.split_by}', using 'sentence' as default") + return self._split_text_by_sentences(text, self.chunk_size) + + def _parser_document(self, raw_document: str, source_path: str = None) -> List[Document]: + """ + Parse raw document into Document objects, with chunking support. + Each chunk from the same source document will have the same group_id. + """ + chunks = self._chunk_document(raw_document) + documents = [] + + # Generate a unique group_id for all chunks from the same source document + group_id = str(uuid.uuid4()) + source_name = Path(source_path).name if source_path else "unknown" + + for i, chunk in enumerate(chunks): + doc_id = str(uuid.uuid4()) + metadata = { + "group_id": group_id, + "source_file": source_name, + "chunk_index": i, + "total_chunks": len(chunks), + "chunk_size": self.chunk_size, + "split_by": self.split_by, + } + documents.append(Document(doc_id=doc_id, content=chunk, metadata=metadata)) + + return documents + + def get_document(self) -> Union[List[Document], None]: + """ + Get all documents with chunking support. + Each source file will generate chunks with the same group_id. + """ + # Safely retrieve document_path from config. If it's missing or falsy, return None. + file_paths = getattr(self.config.data_generation.document_reader, "document_path", None) + if not file_paths: + return None + + # Normalize single string/Path into a list of paths + if isinstance(file_paths, (str, Path)): + file_paths = [str(file_paths)] + + # Ensure we have a concrete list (in case it's a generator or other iterable) + try: + file_paths = list(file_paths) + except Exception: + file_paths = [file_paths] + + all_documents = [] + + for file_path in file_paths: + raw_doc = self.load_document(file_path, languages=list(self.config.data_generation.document_reader.languages)) + # _parser_document now returns a list of documents (chunks) with group_id + documents = self._parser_document(raw_doc, source_path=file_path) + all_documents.extend(documents) + + return all_documents diff --git a/ajet/task_reader/document_reader/document_reader_base.py b/ajet/task_reader/document_reader/document_reader_base.py new file mode 100644 index 00000000..ab7db97c --- /dev/null +++ b/ajet/task_reader/document_reader/document_reader_base.py @@ -0,0 +1,11 @@ +from typing import List, Optional + +from ajet.schema.document import Document + + +class DocReaderBase: + def __init__(self, config): + self.config = config + + def get_document(self) -> Optional[List[Document]]: + raise NotImplementedError diff --git a/ajet/task_reader/env_service_reader.py b/ajet/task_reader/env_service_reader.py new file mode 100644 index 00000000..9abfde0b --- /dev/null +++ b/ajet/task_reader/env_service_reader.py @@ -0,0 +1,36 @@ +from ajet.schema.task import Task +from ajet.task_reader.task_reader_base import BaseTaskReader +from ajet.utils.env_service_client.env_client_ng import EnvClient + + +class EnvServiceTaskReader(BaseTaskReader): + def __init__(self, reader_config): + super().__init__(reader_config) + self.reader_config = reader_config + + def get_tasks(self, split): + env_url = self.reader_config.env_service.env_url + env_type = self.reader_config.env_service.env_type + env_service_client = EnvClient(base_url=env_url) + task_id_array = env_service_client.get_env_profile(env_type, split=split) + if len(task_id_array) == 0: + raise ValueError(f"No task_id found for env_type: {env_type}, split: {split}, Please check connection to {env_url}") + tasks = [ + Task( + main_query="[not defined]", + init_messages=[], + task_id=str(task_id), + env_type=env_type, + metadata={}, + ) + for task_id in task_id_array + ] + return tasks + + def get_validation_tasks(self): + split = self.reader_config.env_service.validation_split + return self.get_tasks(split=split) + + def get_training_tasks(self): + split = self.reader_config.env_service.training_split + return self.get_tasks(split=split) diff --git a/ajet/task_reader/hf_dataset_reader.py b/ajet/task_reader/hf_dataset_reader.py new file mode 100644 index 00000000..41136d81 --- /dev/null +++ b/ajet/task_reader/hf_dataset_reader.py @@ -0,0 +1,74 @@ +from typing import List + +import datasets + +from ajet.schema.task import Task +from ajet.task_reader.task_reader_base import BaseTaskReader + + +class HuggingFaceTaskReader(BaseTaskReader): + """ + Task reader that reads tasks from Hugging Face datasets. + + This class allows loading tasks directly from Hugging Face dataset repositories. + It supports configuring the dataset name and split names for training and validation. + """ + + def __init__(self, reader_config): + super().__init__(reader_config) + self.reader_config = reader_config + + def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: + """ + Load a dataset split from Hugging Face datasets. + + Args: + dataset_name: Name of the dataset in Hugging Face format (e.g., 'gsm8k') + split: Name of the split to load (e.g., 'train', 'validation') + + Returns: + List[Task]: List of Task objects created from the dataset. + """ + try: + dataset = datasets.load_dataset(dataset_name, split=split) + except Exception as e: + raise ValueError(f"Failed to load dataset '{dataset_name}' with split '{split}': {str(e)}") + + # if len(dataset) == 0: + # raise ValueError(f"No examples found in dataset '{dataset_name}' with split '{split}'") + + tasks = [] + for idx, example in enumerate(dataset): + # Create Task object + task = Task( + main_query=example.get("question", "Empty"), + init_messages=[], # Dataset examples typically don't have init messages + task_id=str(idx), + env_type="no_env", + metadata=example, + ) + tasks.append(task) + + return tasks + + def get_training_tasks(self) -> List[Task]: + """ + Get training tasks from the Hugging Face dataset specified in the config. + + Returns: + List[Task]: List of training Task objects. + """ + dataset_name = self.reader_config.huggingface_dat_repo.dataset_path + split = self.reader_config.huggingface_dat_repo.training_split + return self._load_dataset_split(dataset_name, split) + + def get_validation_tasks(self) -> List[Task]: + """ + Get validation tasks from the Hugging Face dataset specified in the config. + + Returns: + List[Task]: List of validation Task objects. + """ + dataset_name = self.reader_config.huggingface_dat_repo.dataset_path + split = self.reader_config.huggingface_dat_repo.validation_split + return self._load_dataset_split(dataset_name, split) diff --git a/ajet/task_reader/jsonl_reader.py b/ajet/task_reader/jsonl_reader.py new file mode 100644 index 00000000..14462371 --- /dev/null +++ b/ajet/task_reader/jsonl_reader.py @@ -0,0 +1,66 @@ +import json +from typing import List + +from ajet.schema.task import Task +from ajet.task_reader.task_reader_base import BaseTaskReader + + +class JsonlTaskReader(BaseTaskReader): + def __init__(self, reader_config): + super().__init__(reader_config) + self.reader_config = reader_config + + def _read_jsonl_file(self, file_path): + """ + Read tasks from a JSONL file. + + Args: + file_path (str): Path to the JSONL file. + + Returns: + List[Task]: List of Task objects. + """ + tasks = [] + try: + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): # Skip empty lines + task_data = json.loads(line) + # Create a Task object from the JSON data + task = Task( + main_query=task_data.get("main_query", "[not defined]"), + init_messages=task_data.get("init_messages", []), + task_id=task_data.get("task_id", ""), + env_type=task_data.get("env_type", "no_env"), + metadata=task_data.get("metadata", task_data), + ) + tasks.append(task) + except FileNotFoundError: + raise ValueError(f"JSONL file not found: {file_path}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in file {file_path}: {str(e)}") + + if len(tasks) == 0: + raise ValueError(f"No tasks found in file: {file_path}") + + return tasks + + def get_training_tasks(self) -> List[Task]: + """ + Get training tasks from the JSONL file specified in the config. + + Returns: + List[Task]: List of training Task objects. + """ + file_path = self.reader_config.jsonl_dataset_file.training.file_path + return self._read_jsonl_file(file_path) + + def get_validation_tasks(self) -> List[Task]: + """ + Get validation tasks from the JSONL file specified in the config. + + Returns: + List[Task]: List of validation Task objects. + """ + file_path = self.reader_config.jsonl_dataset_file.validation.file_path + return self._read_jsonl_file(file_path) diff --git a/ajet/task_reader/task_reader_base.py b/ajet/task_reader/task_reader_base.py new file mode 100644 index 00000000..3f17911c --- /dev/null +++ b/ajet/task_reader/task_reader_base.py @@ -0,0 +1,14 @@ +from typing import List + +from ajet.schema.task import Task + + +class BaseTaskReader: + def __init__(self, reader_config): + self.reader_config = reader_config + + def get_training_tasks(self) -> List[Task]: + raise NotImplementedError + + def get_validation_tasks(self) -> List[Task]: + raise NotImplementedError diff --git a/ajet/task_reader/tracing_reader/__init__.py b/ajet/task_reader/tracing_reader/__init__.py new file mode 100644 index 00000000..2a98a556 --- /dev/null +++ b/ajet/task_reader/tracing_reader/__init__.py @@ -0,0 +1,123 @@ +import json +import os +import random +from typing import Any, List, Mapping, TypedDict + +from loguru import logger + +from ajet.schema.task import Task +from ajet.task_reader.tracing_reader.filters.base import Filter +from ajet.task_reader.tracing_reader.filters.factory import build_filters + +from ..task_reader_base import BaseTaskReader + + +class Config(TypedDict): + base_url: str + train_output_path: str + filters: List[Mapping[str, Any]] + + +class TracingReader(BaseTaskReader): + def __init__( + self, + reader_config, + train_ratio: float = 0.7, + split_seed: int = 42, + ) -> None: + from ajet.task_reader.tracing_reader.connector import ( + LocalSqliteConnectorV1, + ) + + super().__init__(reader_config) + # config patch + self.reader_config = reader_config.feedback_tracing + + logger.info(f"reading tasks from {self.reader_config.get('base_url')}, #filter {len(self.reader_config.get('filters', []))}") + self._connector = LocalSqliteConnectorV1(self.reader_config.get("base_url")) + filters_config = self.reader_config.get("filters") + built_filters = build_filters(filters_config) + self._filters: List[Filter] = built_filters + + self._train_ratio = train_ratio + self._split_seed = split_seed + + self._train_tasks: List[Task] = [] + self._val_tasks: List[Task] = [] + + self._init_tasks() + + def _load_existing_tasks(self, path: str) -> List[Task]: + if not os.path.exists(path): + return [] + tasks: List[Task] = [] + with open(path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + obj = json.loads(line) + tasks.append(Task(**obj)) + return tasks + + def _append_tasks(self, path: str, tasks: List[Task]) -> None: + if not tasks: + return + mode = "a" if os.path.exists(path) else "w" + with open(path, mode) as f: + for task in tasks: + obj = task.model_dump() + f.write(json.dumps(obj, ensure_ascii=False) + "\n") + + def _apply_filters(self, tasks: List[Task]) -> List[Task]: + filtered = tasks + for flt in self._filters: + filtered = flt.filter_sync(filtered) + return filtered + + def _init_tasks(self) -> None: + output_path = self.reader_config.get("train_output_path") + + tasks = self._connector.load_tasks_from_conversation() + logger.info(f"Loaded {len(tasks)} tasks from conversation") + + existing_tasks = self._load_existing_tasks(output_path) + existing_hashes = {t.metadata.get("qa_hash") for t in existing_tasks if t.metadata.get("qa_hash") is not None} + + new_tasks = [t for t in tasks if t.metadata.get("qa_hash") is not None and t.metadata["qa_hash"] not in existing_hashes] + + new_tasks_filtered = self._apply_filters(new_tasks) + + self._append_tasks(output_path, new_tasks_filtered) + + all_tasks: List[Task] = existing_tasks + new_tasks_filtered + + if not all_tasks: + self._train_tasks = [] + self._val_tasks = [] + return + + shuffled_tasks = list(all_tasks) + rnd = random.Random(self._split_seed) + rnd.shuffle(shuffled_tasks) + + total = len(shuffled_tasks) + train_size = int(total * self._train_ratio) + + if total == 1: + train_size = 1 + else: + if train_size <= 0: + train_size = 1 + if train_size >= total: + train_size = total - 1 + + self._train_tasks = shuffled_tasks[:train_size] + self._val_tasks = shuffled_tasks[train_size:] + logger.info(f"Shuffled {total} tasks into {train_size} train and {total - train_size} val") + + def get_training_tasks(self) -> List[Task]: + return self._train_tasks + + def get_validation_tasks(self) -> List[Task]: + return self._val_tasks diff --git a/ajet/task_reader/tracing_reader/connector.py b/ajet/task_reader/tracing_reader/connector.py new file mode 100644 index 00000000..6dc1f17c --- /dev/null +++ b/ajet/task_reader/tracing_reader/connector.py @@ -0,0 +1,260 @@ +import ast +import hashlib +import json +import os +import re +import sqlite3 +from datetime import datetime +from typing import List, Protocol + +import requests +from loguru import logger + +from ajet.schema.task import Task + + +class TracingConnector(Protocol): + def load_tasks_from_conversation(self) -> List[Task]: + ... + + +class PhoenixConnector: + """ + PhoneixConnector is a class that connects to the Phoneix API. + + Args: + base_url (str): The base URL of the Phoneix API. + projects_limit (int): The maximum number of projects to load. + spans_limit (int): The maximum number of spans to load from each project. + + Methods: + load_spans(self, projects_limit: int = 100, spans_limit: int = 100) -> list: + Load all spans from all projects. + + load_tasks_from_conversation(self) -> List[Task]: + Load all tasks from the conversation spans. + + Attributes: + _base_url (str): The base URL of the Phoneix API. + _projects_limit (int): The maximum number of projects to load. + _spans_limit (int): The maximum number of spans to load from each project. + """ + + def __init__(self, base_url: str, projects_limit: int = 100, spans_limit: int = 100) -> None: + self._base_url = base_url.rstrip("/") + self._projects_limit = projects_limit + self._spans_limit = spans_limit + + def _get(self, path: str, **params): + url = f"{self._base_url}{path}" + resp = requests.get(url, params=params) + resp.raise_for_status() + return resp.json()["data"] + + def load_spans(self, projects_limit: int = 100, spans_limit: int = 100) -> list: + projects = self._get( + "/v1/projects", + limit=projects_limit, + include_experiment_projects="false", + ) + + all_spans: list = [] + for project in projects: + pid = project["id"] + spans = self._get(f"/v1/projects/{pid}/spans", limit=spans_limit) + all_spans.extend(spans) + return all_spans + + def load_tasks_from_conversation(self) -> List[Task]: + all_spans = self.load_spans(projects_limit=self._projects_limit, spans_limit=self._spans_limit) + all_spans.sort(key=lambda x: datetime.fromisoformat(x["end_time"])) + all_spans = list(filter(lambda x: x["name"].startswith("invoke_agent"), all_spans)) + + qa: list = [] + for span in all_spans: + inp = json.loads(span["attributes"]["gen_ai.input.messages"]) + out = json.loads(span["attributes"]["gen_ai.output.messages"]) + if "parts" in inp and "parts" in out: + qa.append( + { + "query": inp["parts"][0]["content"], + "answer": out["parts"][0]["content"], + } + ) + + tasks: List[Task] = [] + for item in qa: + raw = (item["query"] or "") + "\n" + (item["answer"] or "") + qa_hash = hashlib.sha256(raw.encode("utf-8")).hexdigest() + task = Task( + main_query=item["query"], + task_id="no_id", + env_type="no_env", + metadata={ + "answer": item["answer"], + "qa_hash": qa_hash, + }, + ) + tasks.append(task) + return tasks + + +def parse_msg_line(line: str): + """ + Extract role and content from Msg(...). + """ + match = re.search(r"Msg\((.*)\)", line, re.DOTALL) + if not match: + return None + + inner = match.group(1) + + kv_pairs = [] + for item in re.findall(r"(\w+)=((?:'.*?'|\[.*?\]|None))", inner): + key, val = item + kv_pairs.append(f"'{key}': {val}") + dict_like = "{" + ", ".join(kv_pairs) + "}" + + try: + data = ast.literal_eval(dict_like) + except Exception as e: + print("Parse failed:", e) + return None + + role = data.get("role") + content = data.get("content") + return {"role": role, "content": content} + + +class LocalSqliteConnectorV1: + """ + A connector that loads tasks from a SQLite database file. + + Args: + db_path (str): Path to the SQLite database file. + + Attributes: + _db_path (str): Path to the SQLite database file. + + Methods: + load_tasks_from_conversation (self) -> List[Task]: + Load tasks from a conversation in the SQLite database file. + """ + + def __init__(self, db_path: str) -> None: + self._db_path = db_path + assert os.path.exists(self._db_path), f"DB file {self._db_path} does not exist" + + def load_tasks_from_conversation(self) -> List[Task]: + conn = sqlite3.connect(self._db_path) + cursor = conn.cursor() + rows = cursor.execute("SELECT attributes FROM span_table where name='ReActAgent.reply'").fetchall() + logger.debug(f"Loaded {len(rows)} rows from {self._db_path}") + + qa = [] + for row in rows: + js = json.loads(row[0]) + query = js["input"]["kwargs"]["msg"] + output = js["output"] if "output" in js else None + if query is not None and output is not None: + query = parse_msg_line(query) + output = parse_msg_line(output) + # patch + if isinstance(output["content"], list): + output["content"] = output["content"][-1] + if isinstance(output["content"], dict): + output["content"] = output["content"]["text"] + if query is not None and output is not None: + if query["role"] == "user" and output["role"] == "assistant": + if query["content"] is not None and output["content"] is not None: + qa.append( + { + "query": query["content"], + "answer": output["content"], + } + ) + + conn.close() + + tasks: List[Task] = [] + for item in qa: + raw = (item["query"] or "") + "\n" + (item["answer"] or "") + qa_hash = hashlib.sha256(raw.encode("utf-8")).hexdigest() + task = Task( + main_query=item["query"], + task_id="no_id", + env_type="no_env", + metadata={ + "answer": item["answer"], + "qa_hash": qa_hash, + }, + ) + tasks.append(task) + + return tasks + + +class LocalSqliteConnectorV2: + """ + A connector that loads tasks from a SQLite database file in new format. + + https://github.com/agentscope-ai/agentscope-studio/pull/40/files#diff-12c7e27505a5171e080133021430d8cae2f4929ce2f4c93bd4ea5a389094224a + + Args: + db_path (str): Path to the SQLite database file. + + Attributes: + _db_path (str): Path to the SQLite database file. + + Methods: + load_tasks_from_conversation (self) -> List[Task]: + Load tasks from a conversation in the SQLite database file. + """ + + def __init__(self, db_path: str) -> None: + self._db_path = db_path + assert os.path.exists(self._db_path), f"DB file {self._db_path} does not exist" + + def load_tasks_from_conversation(self) -> List[Task]: + conn = sqlite3.connect(self._db_path) + cursor = conn.cursor() + rows = cursor.execute("SELECT attributes FROM span_table where name='ReActAgent.reply'").fetchall() + + qa = [] + for row in rows: + js = json.loads(row[0]) + inp = json.loads(js["gen_ai"]["input"]["messages"]) + out = json.loads(js["gen_ai"]["output"]["messages"]) + if "parts" in inp and "parts" in out: + qa.append( + { + "query": inp["parts"][0]["content"], + "answer": out["parts"][0]["content"], + } + ) + + conn.close() + + tasks: List[Task] = [] + for item in qa: + raw = (item["query"] or "") + "\n" + (item["answer"] or "") + qa_hash = hashlib.sha256(raw.encode("utf-8")).hexdigest() + task = Task( + main_query=item["query"], + task_id="no_id", + env_type="no_env", + metadata={ + "answer": item["answer"], + "qa_hash": qa_hash, + }, + ) + tasks.append(task) + + return tasks + + +__all__ = [ + "LocalSqliteConnectorV1", + "LocalSqliteConnectorV2", + "PhoenixConnector", +] diff --git a/ajet/task_reader/tracing_reader/filters/__init__.py b/ajet/task_reader/tracing_reader/filters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ajet/task_reader/tracing_reader/filters/base.py b/ajet/task_reader/tracing_reader/filters/base.py new file mode 100644 index 00000000..fbdde8d0 --- /dev/null +++ b/ajet/task_reader/tracing_reader/filters/base.py @@ -0,0 +1,40 @@ +import asyncio +import threading +from abc import ABC, abstractmethod +from typing import Iterable, List + +from ajet.schema.task import Task + + +class Filter(ABC): + @abstractmethod + async def filter(self, tasks: Iterable[Task]) -> List[Task]: + """Filter a collection of Task objects and return the kept ones.""" + raise NotImplementedError + + def filter_sync(self, tasks: Iterable[Task]) -> List[Task]: + """This is a temp fix for async filter being called in a sync context.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(self.filter(tasks)) + + res_holder: dict[str, object] = {} + err_holder: dict[str, BaseException] = {} + + def _runner() -> None: + try: + res_holder["res"] = asyncio.run(self.filter(tasks)) + except BaseException as e: + err_holder["err"] = e + + t = threading.Thread(target=_runner, daemon=True) + t.start() + t.join() + + if "err" in err_holder: + raise err_holder["err"] + + res = res_holder.get("res") + assert isinstance(res, list) + return res diff --git a/ajet/task_reader/tracing_reader/filters/deduplication_filter.py b/ajet/task_reader/tracing_reader/filters/deduplication_filter.py new file mode 100644 index 00000000..588a2a56 --- /dev/null +++ b/ajet/task_reader/tracing_reader/filters/deduplication_filter.py @@ -0,0 +1,44 @@ +import os +import shutil +from typing import Iterable, List + +from ajet.schema.task import Task +from ajet.utils.embedding_client import EmbeddingClient + +from .base import Filter + + +class DeduplicationFilter(Filter): + def __init__( + self, + similarity_threshold: float, + db_path: str, + model: str, + api_key: str | None, + base_url: str, + ): + # remove old db + if os.path.exists(db_path): + shutil.rmtree(db_path) + + self._client = EmbeddingClient( + similarity_threshold=similarity_threshold, + base_url=base_url, + api_key=api_key, + model=model, + chroma_db_path=db_path, + ) + + self._similarity_threshold = similarity_threshold + self._db_path = db_path + + async def filter(self, tasks: Iterable[Task]) -> List[Task]: + res = [] + for task in tasks: + similar = self._client.find_top_k_by_text(task.main_query, k=1) + if len(similar) != 0 and similar[0][1] >= self._similarity_threshold: + continue + res.append(task) + self._client.add(task.main_query, hash(task.main_query)) + + return res diff --git a/ajet/task_reader/tracing_reader/filters/factory.py b/ajet/task_reader/tracing_reader/filters/factory.py new file mode 100644 index 00000000..66251e0d --- /dev/null +++ b/ajet/task_reader/tracing_reader/filters/factory.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Mapping, MutableMapping, Sequence + +from .base import Filter +from .deduplication_filter import DeduplicationFilter +from .llm_evaluate_filter import LlmEvaluateFilter + +FILTER_REGISTRY: Dict[str, type[Filter]] = { + "llm_evaluate": LlmEvaluateFilter, + "deduplication": DeduplicationFilter, +} + + +def _build_single_filter(spec: Mapping[str, Any]) -> Filter: + type_name = spec.get("type") + if not isinstance(type_name, str): + raise ValueError(f"Filter spec must contain string 'type', got: {type_name!r}") + + params = spec.get("params") or {} + if not isinstance(params, MutableMapping): + raise TypeError("Filter 'params' must be a mapping if present") + + cls = FILTER_REGISTRY.get(type_name) + if cls is None: + raise ValueError(f"Unknown filter type: {type_name!r}") + + return cls(**params) # type: ignore[arg-type] + + +def build_filters(specs: Sequence[Mapping[str, Any]] | None) -> List[Filter]: + """Setup filter chain according to config.""" + if not specs: + return [] + + filters: List[Filter] = [] + for spec in specs: + enabled = spec.get("enabled", True) + if not enabled: + continue + filters.append(_build_single_filter(spec)) + return filters diff --git a/ajet/task_reader/tracing_reader/filters/llm_evaluate_filter.py b/ajet/task_reader/tracing_reader/filters/llm_evaluate_filter.py new file mode 100644 index 00000000..6cab2bd9 --- /dev/null +++ b/ajet/task_reader/tracing_reader/filters/llm_evaluate_filter.py @@ -0,0 +1,89 @@ +import os +from typing import Iterable, List + +from agentscope.agent import ReActAgent +from agentscope.formatter import DashScopeMultiAgentFormatter +from agentscope.message import Msg +from agentscope.model import DashScopeChatModel +from pydantic import BaseModel, Field + +from ajet.schema.task import Task +from ajet.task_rollout.dashscope_llm_bridge import create_external_llm_fn + +from .base import Filter + +EVALUATE_PROMPT = """You are now acting as a **strict QA quality reviewer**. You will be given a data sample containing a “query” (user question/task) and an “answer” (assistant reply). Evaluate it **only based on the text itself**, without inventing facts or performing external retrieval. + +--- + +## 1. Evaluation Goal +Determine whether the given “query-answer” pair is **high-quality data (GOOD)** and provide a score and justification. +If it does not meet the criteria, label it as **BAD**. + +--- + +## 2. BAD Criteria (if any are met → BAD) +1. **Missing elements**: The query is empty, the answer is empty, or both are empty. +2. **Non-answer**: The answer contains only acknowledgments such as “Received / OK / Please provide more information,” without substantive content or actionable results. +3. **Irrelevant**: The answer is clearly unrelated to the query. +4. **Process excuses**: The answer mainly describes process issues (“cannot search / rate-limited / captcha / try another device”), **without** providing alternative information, summaries, or next steps. +5. **Self-contradiction or illogical**: The answer contradicts itself or contains major logical inconsistencies. +6. **Safety or compliance violations**: Includes illegal content, hate speech, personal privacy leaks, or other clearly inappropriate material. +7. **Severe language mismatch**: The answer is in a completely different language from the query in a way that breaks comprehension (e.g., Chinese query but irrelevant and incoherent French reply). + +--- + +## 3. Special Cases & Additional Rules +{custom_rubrics} + +--- + +If **any** of the above conditions are triggered, the final result must be **BAD**. Otherwise, it is **GOOD**. +""" + + +class EvalResModel(BaseModel): + reason: str = Field( + description="judgment reason, briefly explain the reason", + ) + result: str = Field( + description="GOOD/BAD", + ) + + +class LlmEvaluateFilter(Filter): + def __init__( + self, + *, + custom_rubrics: str = "", + temperature: float = 0.7, + max_tokens: int = 2048, + print_reason: bool = True, + ) -> None: + """Filter that evaluates the quality of tasks using LLM.""" + + self._print_reason = print_reason + self.external_llm_fn = create_external_llm_fn( + alien_llm_model="qwen3-235b-a22b-instruct-2507", + alien_llm_response_length=512, + ) + self._fn = ReActAgent( + name="agent", + sys_prompt=EVALUATE_PROMPT.format(custom_rubrics=custom_rubrics), + model=DashScopeChatModel("qwen3-235b-a22b-instruct-2507", os.environ["DASHSCOPE_API_KEY"]), + formatter=DashScopeMultiAgentFormatter(), + max_iters=1, + ) + + async def filter(self, tasks: Iterable[Task]) -> List[Task]: + kept: List[Task] = [] + for task in tasks: + payload = "query: " + task.main_query + "\n" "answer: " + task.metadata.get("answer", "") + + res = await self._fn(Msg("user", content=payload, role="user"), structured_model=EvalResModel) + assert isinstance(res, EvalResModel) + if self._print_reason: + print(res.reason) + if res.result == "GOOD": + kept.append(task) + return kept diff --git a/ajet/task_reader/tracing_reader/llm_client.py b/ajet/task_reader/tracing_reader/llm_client.py new file mode 100644 index 00000000..e414dcdd --- /dev/null +++ b/ajet/task_reader/tracing_reader/llm_client.py @@ -0,0 +1,283 @@ +import json +import os +import time +from typing import Any, Generator, Optional, cast + +import requests +from loguru import logger + + +class LlmException(Exception): + def __init__(self, typ: str): + self._type = typ + + @property + def typ(self): + return self._type + + +class DashScopeClient: + """Aliyun DashScope API Client""" + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "qwen-plus", + temperature: float = 0.7, + max_tokens: int = 2048, + ): + self.api_key = api_key or os.getenv("DASHSCOPE_API_KEY") + if not self.api_key: + raise ValueError("API key is required. Please set DASHSCOPE_API_KEY environment variable or pass it directly.") + + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + self.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + def set_model(self, model_name: str): + """ + Sets the model name for the DashScopeClient instance. + + Args: + model_name (str): The name of the model to be used for API interactions. + """ + self.model_name = model_name # ⭐ Assigns the provided model name to the instance variable + + def chat(self, messages: list[dict[str, str]], sampling_params: dict[str, Any]) -> str: + """ + Sends a chat request to the LLM, aggregates the streaming responses, and returns the complete response. + + Args: + messages (list[dict[str, str]]): A list of message dictionaries, each containing 'role' and 'content'. + sampling_params (dict[str, Any]): Parameters for controlling the sampling behavior of the LLM. + + Returns: + str: The complete response from the LLM as a single string. + """ + res = "" + for x in self.chat_stream(messages, sampling_params): # ⭐ Aggregates the streaming responses into a single string + res += x + return res + + def chat_stream(self, messages: list[dict[str, str]], sampling_params: dict[str, Any]) -> Generator[str, None, None]: + """ + Initiates a streaming chat session and returns a generator that yields the response as it is being generated. + + Args: + messages (list[dict[str, str]]): A list of message objects, each containing 'role' and 'content'. + sampling_params (dict[str, Any]): Parameters for controlling the sampling behavior of the model. + + Returns: + Generator[str, None, None]: A generator that yields the response text as it is being generated. + """ + return self.chat_stream_with_retry(messages, **sampling_params) # ⭐ Calls the retry mechanism for streaming chat + + def chat_completion(self, messages: list[dict[str, str]], stream: bool = False, **kwargs) -> str | Generator[str, None, None]: + """ + Sends a request to the chat completion API, supporting both non-streaming and streaming modes, and handles various exceptions. + + Args: + messages (list[dict[str, str]]): A list of message objects, each containing 'role' and 'content'. + stream (bool, optional): If True, the response will be streamed. Defaults to False. + **kwargs: Additional parameters to be passed to the API. + + Returns: + str | Generator[str, None, None]: The full response text if not streaming, or a generator yielding the response text if streaming. + """ + url = f"{self.base_url}/chat/completions" + + # Merge parameters + params = { + "model": self.model_name, + "messages": messages, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "stream": stream, + **kwargs, + } + + try: + if stream: + return self._handle_stream_response(url, params) # ⭐ Handles the streaming response + else: + return self._handle_normal_response(url, params) # ⭐ Handles the non-streaming response + + except requests.exceptions.RequestException as e: + logger.error(f"API request failed: {e}") + return "" if not stream else (x for x in []) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse API response: {e}") + return "" if not stream else (x for x in []) + except Exception as e: + logger.error(f"Unexpected error in API call: {e}") + return "" if not stream else (x for x in []) + + def _handle_normal_response(self, url: str, params: dict) -> str: + """ + Handles the non-streaming (normal) response from the API. + + Args: + url (str): The URL to which the POST request is sent. + params (dict): The parameters to be included in the JSON body of the POST request. + + Returns: + str: The content of the first choice's message in the response, or an empty string if the response format is unexpected. + """ + response = requests.post(url, headers=self.headers, json=params, timeout=600) # ⭐ Sends the POST request to the API + if not response.ok: + # check inappropriate content + try: + error_json = response.json()["error"] + if "inappropriate content" in error_json["message"]: + raise LlmException("inappropriate content") + if "limit" in error_json["message"]: + raise LlmException("hit limit") + except LlmException: + raise + except Exception: + logger.error(f"API request failed: {response.text}") + response.raise_for_status() + + result = response.json() + if "choices" in result and len(result["choices"]) > 0: + return result["choices"][0]["message"]["content"].strip() # ⭐ Extracts and returns the content of the first choice's message + else: + logger.error(f"Unexpected response format: {result}") + return "" + + def _handle_stream_response(self, url: str, params: dict) -> Generator[str, None, None]: + """ + Handles the streaming response from a POST request to the specified URL. + + Args: + url (str): The URL to which the POST request is sent. + params (dict): The parameters to be sent with the POST request. + + Yields: + str: The content of the response, if it meets the specified conditions. + """ + response = requests.post(url, headers=self.headers, json=params, stream=True, timeout=600) # ⭐ Send the POST request and get the streaming response + if not response.ok: + # check inappropriate content + try: + error_json = response.json()["error"] + if "inappropriate content" in error_json["message"]: + raise LlmException("inappropriate content") + if "limit" in error_json["message"]: + raise LlmException("hit limit") + except LlmException: + raise + except Exception: + logger.error(f"API request failed: {response.text}") + response.raise_for_status() + + for line in response.iter_lines(): + if line: + line = line.decode("utf-8") + if line.startswith("data: "): + data = line[6:] # remove the prefix 'data: ' + if data == "[DONE]": + break + + try: + chunk = json.loads(data) + if "choices" in chunk and len(chunk["choices"]) > 0: + choice = chunk["choices"][0] + if "delta" in choice and "content" in choice["delta"]: + content = choice["delta"]["content"] + if content: + yield content # ⭐ Yield the content if it meets the conditions + except json.JSONDecodeError: + continue # skip the bad line + + def chat_with_retry( + self, + messages: list[dict[str, str]], + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> str: + """ + Sends a chat completion request to the LLM with a retry mechanism. + + Args: + messages (list[dict[str, str]]): A list of message dictionaries for the chat. + max_retries (int, optional): Maximum number of retries. Defaults to 3. + retry_delay (float, optional): Initial delay between retries in seconds. Defaults to 1.0. + **kwargs: Additional keyword arguments to be passed to the `chat_completion` method. + + Returns: + str: The response from the LLM or a predefined message if all attempts fail. + """ + for attempt in range(max_retries): + try: + result = cast(str, self.chat_completion(messages, stream=False, **kwargs)) # ⭐ Attempt to get a chat completion + if result: # If a valid response is obtained + return result + + except LlmException as e: + if e.typ == "inappropriate content": + logger.warning("llm return inappropriate content, which is blocked by the remote") + return "[inappropriate content]" + except Exception as e: + logger.warning(f"Attempt {attempt + 1} failed: {e}") + + if attempt < max_retries - 1: # Not the last attempt + time.sleep(retry_delay * (2**attempt)) # Exponential backoff + + logger.error(f"All {max_retries} attempts failed") + return "" + + def chat_stream_with_retry( + self, + messages: list[dict[str, str]], + max_retries: int = 3, + retry_delay: float = 10.0, + **kwargs, + ) -> Generator[str, None, None]: + """ + Attempts to establish a streaming chat completion with a retry mechanism. + + Args: + messages (list[dict[str, str]]): A list of message dictionaries, each containing 'role' and 'content'. + max_retries (int, optional): The maximum number of retry attempts. Defaults to 3. + retry_delay (float, optional): The initial delay in seconds before the first retry. Defaults to 10.0. + **kwargs: Additional keyword arguments to pass to the chat_completion method. + + Yields: + str: Chunks of the streaming response. + """ + for attempt in range(max_retries): + try: + stream_generator = cast( + Generator[str, None, None], + self.chat_completion(messages, stream=True, **kwargs), + ) # ⭐ Cast the generator to the appropriate type + # try to fetch the first chunk to verify the connection + first_chunk = next(stream_generator, None) + if first_chunk is not None: + yield first_chunk + # yield the rest chunks + for chunk in stream_generator: + yield chunk + return # success + except LlmException as e: + if e.typ == "inappropriate content": + logger.warning("llm return inappropriate content, which is blocked by the remote") + yield "[inappropriate content]" + return + except Exception as e: + logger.warning(f"Stream attempt {attempt + 1} failed: {e}") + + if attempt < max_retries - 1: + time.sleep(retry_delay * (2**attempt)) + + logger.error(f"All {max_retries} stream attempts failed") + + return diff --git a/ajet/task_rollout/__init__.py b/ajet/task_rollout/__init__.py new file mode 100644 index 00000000..b718083c --- /dev/null +++ b/ajet/task_rollout/__init__.py @@ -0,0 +1 @@ +"""Task rollout helper modules.""" diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py new file mode 100644 index 00000000..cafdffc2 --- /dev/null +++ b/ajet/task_rollout/async_llm_bridge.py @@ -0,0 +1,549 @@ +import asyncio +import copy +import json +import time +import uuid +from typing import Any, Callable, Dict, List, Literal, Type, Union + + +from loguru import logger +from omegaconf import DictConfig +from pydantic import BaseModel +from transformers.tokenization_utils import PreTrainedTokenizer +from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser +from vllm.outputs import RequestOutput as VerlVllmRequestOutput + +from agentscope.model import ChatResponse as AgentScopeChatResponse +from openai.types.chat.chat_completion import ChatCompletion as OpenAIChatCompletion + +ChatResponse = Union[OpenAIChatCompletion, AgentScopeChatResponse] + +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.schema.convertion import convert_llm_proxy_response_to_oai_response +from ajet.schema.convertion import convert_llm_proxy_response_to_agentscope_response +from ajet.schema.logprob import TokenAndProb +from ajet.utils.async_utils import run_async_coroutine_with_timeout +from ajet.utils.testing_utils import _mock_if_test_mode, _test_if_test_mode +from ajet.utils.tokenizer import ajet_apply_chat_template + + +class AjetStandardLlmBridgeRequest(BaseModel): + messages: List[Dict[str, str]] + custom_sampling_params: dict = {} + tools: List = [] + request_id: str = "" + + +class AjetStandardLlmBridgeResponse(BaseModel): + role: str = "assistant" + request_id: str = "" + content: str = "" + tool_calls: List[Dict] = [] + tokens: List[TokenAndProb] = [] + + +# ------------------------------------------------------------------------------------- +# ------------------------ Unify LLM for Verl + Trinity + Vllm ------------------------ +# ------------------------------------------------------------------------------------- + + +class AsyncLlmBridge(object): + def __init__( + self, + config: DictConfig, + async_rollout_manager: Any, + tokenizer: Any, + llm_mode: Literal["local", "remote", "trinity"] = "local", + max_llm_retries: int = 3, + ): + self.config = config + self.async_rollout_manager = async_rollout_manager + self.tokenizer = tokenizer + self.llm_mode = llm_mode + self.max_llm_retries = max_llm_retries + self.tool_parser = Hermes2ProToolParser(self.tokenizer) + + def get_llm_inference_fn_sync(self, sampling_params: dict = {}) -> Callable: # noqa: C901 + def llm_chat_verl( + messages: List[Dict[str, str]], + custom_sampling_params: dict = {}, + tools=[], + request_id: str = "", + ) -> dict: + request_id = uuid.uuid4().hex + + updated_sampling_params = {} + if sampling_params: + updated_sampling_params.update(sampling_params) + if custom_sampling_params: + updated_sampling_params.update(custom_sampling_params) + + input_messages = copy.deepcopy(messages) + prompt_text = ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=input_messages, + tools=tools, + add_generation_prompt=True, + tokenize=False, + ) + prompt_ids = self.tokenizer(prompt_text)["input_ids"] + + if self.config.ajet.execute_test: + _test_if_test_mode("prompt_text", prompt_text, self.config) + + final_res = run_async_coroutine_with_timeout( + self.async_rollout_manager.generate( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=updated_sampling_params, + ), + timeout=1800, + ) + + if self.config.ajet.rollout.name == "vllm": + final_res: VerlVllmRequestOutput + token_array = final_res.outputs[0].token_ids + logprob_array = final_res.outputs[0].logprobs + elif self.config.ajet.rollout.name == "sglang": + token_array = final_res + + decoded_text = self.tokenizer.decode(token_array) # type: ignore + if self.config.ajet.execute_test: + decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config) + + if decoded_text.endswith("<|im_end|>"): + decoded_text = decoded_text[: -len("<|im_end|>")] + + # if tool call + tool_calls = None + if ("" in decoded_text) and ("" in decoded_text) and (not self.config.ajet.rollout.force_disable_toolcalls): + parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore + parsed_tool_calls = parsed_tool_calls.model_dump() + if self.config.ajet.execute_test: + _test_if_test_mode("parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config) + model_called = parsed_tool_calls["tools_called"] + if model_called: + tool_calls = parsed_tool_calls["tool_calls"] + is_bad_toolcall = False + for i in range(len(tool_calls)): + if "function" in tool_calls[i] and "arguments" in tool_calls[i]["function"]: + expect_dict = json.loads(tool_calls[i]["function"]["arguments"]) + if not isinstance(expect_dict, dict): + is_bad_toolcall = True + if is_bad_toolcall: + tool_calls = None + decoded_text = decoded_text + else: + decoded_text = parsed_tool_calls["content"] + if decoded_text is None: + decoded_text = "" + + return { + "role": "assistant", + "request_id": request_id, + "content": decoded_text, + "tool_calls": tool_calls, + "tokens": [ + TokenAndProb( + token_id=token_id, + logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only. + decoded_string=logprob[token_id].decoded_token, + ) + for token_id, logprob in zip(token_array, logprob_array) # type: ignore + ], + } + + def llm_chat_remote( + messages: List[Dict[str, str]], + custom_sampling_params: dict = {}, + tools=[], + request_id: str = "", + ) -> dict: + updated_sampling_params = {} + if sampling_params: + updated_sampling_params.update(sampling_params) + if custom_sampling_params: + updated_sampling_params.update(custom_sampling_params) + updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) + input_messages = copy.deepcopy(messages) + for i in range(self.max_llm_retries): + try: + # this function is defined in `ajet/backbone/main_vllm.py` + output_message = self.async_rollout_manager.submit_chat_completions( + messages=input_messages, + sampling_params=updated_sampling_params, + tools=tools, + request_id=request_id, + ) + break + except Exception as e: + logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}") + time.sleep(i + 1) + return output_message[-1] # type: ignore + + def llm_chat_trinity( + messages: List[Dict[str, str]], + custom_sampling_params: dict = {}, + tools=[], + request_id: str = "", + ) -> dict: + async def main(): + updated_sampling_params = {} + if sampling_params: + updated_sampling_params.update(sampling_params) + if custom_sampling_params: + updated_sampling_params.update(custom_sampling_params) + updated_sampling_params.pop("min_tokens") + + if tools: + response = await self.async_rollout_manager.chat.completions.create( + model=self.async_rollout_manager.model_path, + messages=messages, + logprobs=True, + tools=tools, + top_logprobs=0, + **updated_sampling_params, + ) + else: + response = await self.async_rollout_manager.chat.completions.create( + model=self.async_rollout_manager.model_path, + messages=messages, + logprobs=True, + top_logprobs=0, + **updated_sampling_params, + ) + return response + + response = run_async_coroutine_with_timeout(main(), timeout=1800) # type: ignore + prompt_text = self.tokenizer.decode(response.model_extra["prompt_token_ids"]) + prompt_token_ids = response.model_extra["prompt_token_ids"] + content = response.choices[0].message.content + message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + + if content is None: + content = "" + + if ("" in content) and (not message.get("tool_calls", None)): + # logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}") + logger.warning(f"Bad toolcall discovered: {content}") + + return { + "role": "assistant", + "request_id": response.id, + "content": content, + "prompt_text": prompt_text, + "prompt_token_ids": prompt_token_ids, + "tool_calls": message.get("tool_calls", []), + "tokens": [ + TokenAndProb( + token_id=token, + logprob=tokenlogprob.logprob, # Warning: vllm logprob does not participant training, for log only. + decoded_string=tokenlogprob.token, + ) + for tokenlogprob, token in zip( + response.choices[0].logprobs.content, + response.choices[0].token_ids, + ) + ], + } + + if self.llm_mode == "remote": + return llm_chat_remote + if self.llm_mode == "trinity": + return llm_chat_trinity + else: + return llm_chat_verl + + def get_llm_inference_fn_async(self, sampling_params: dict = {}) -> Callable: # noqa: C901 + async def llm_chat_verl( + messages: List[Dict[str, str]], + custom_sampling_params: dict = {}, + tools=[], + request_id: str = "", + ) -> dict: + request_id = uuid.uuid4().hex + + updated_sampling_params = {} + if sampling_params: + updated_sampling_params.update(sampling_params) + if custom_sampling_params: + updated_sampling_params.update(custom_sampling_params) + + input_messages = copy.deepcopy(messages) + prompt_text = ajet_apply_chat_template( + tokenizer=self.tokenizer, + conversation=input_messages, + tools=tools, + add_generation_prompt=True, + tokenize=False, + ) + prompt_ids = self.tokenizer(prompt_text)["input_ids"] + + if self.config.ajet.execute_test: + _test_if_test_mode("prompt_text", prompt_text, self.config) + + final_res = await self.async_rollout_manager.generate( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=updated_sampling_params, + ) + + if self.config.ajet.rollout.name == "vllm": + final_res: VerlVllmRequestOutput + token_array = final_res.outputs[0].token_ids + logprob_array = final_res.outputs[0].logprobs + elif self.config.ajet.rollout.name == "sglang": + token_array = final_res + + decoded_text = self.tokenizer.decode(token_array) # type: ignore + if self.config.ajet.execute_test: + decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config) + + if decoded_text.endswith("<|im_end|>"): + decoded_text = decoded_text[: -len("<|im_end|>")] + + # if tool call + tool_calls = None + if ("" in decoded_text) and ("" in decoded_text) and (not self.config.ajet.rollout.force_disable_toolcalls): + parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore + parsed_tool_calls = parsed_tool_calls.model_dump() + if self.config.ajet.execute_test: + _test_if_test_mode("parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config) + model_called = parsed_tool_calls["tools_called"] + if model_called: + tool_calls = parsed_tool_calls["tool_calls"] + is_bad_toolcall = False + for i in range(len(tool_calls)): + if "function" in tool_calls[i] and "arguments" in tool_calls[i]["function"]: + expect_dict = json.loads(tool_calls[i]["function"]["arguments"]) + if not isinstance(expect_dict, dict): + is_bad_toolcall = True + if is_bad_toolcall: + tool_calls = None + decoded_text = decoded_text + else: + decoded_text = parsed_tool_calls["content"] + if decoded_text is None: + decoded_text = "" + + return { + "role": "assistant", + "request_id": request_id, + "content": decoded_text, + "tool_calls": tool_calls, + "tokens": [ + TokenAndProb( + token_id=token_id, + logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only. + decoded_string=logprob[token_id].decoded_token, + ) + for token_id, logprob in zip(token_array, logprob_array) # type: ignore + ], + } + + async def llm_chat_remote( + messages: List[Dict[str, str]], + custom_sampling_params: dict = {}, + tools=[], + request_id: str = "", + ) -> dict: + updated_sampling_params = {} + if sampling_params: + updated_sampling_params.update(sampling_params) + if custom_sampling_params: + updated_sampling_params.update(custom_sampling_params) + updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) + input_messages = copy.deepcopy(messages) + for i in range(self.max_llm_retries): + try: + # this function is defined in `ajet/backbone/main_vllm.py` + output_message = await self.async_rollout_manager.submit_chat_completions_async( + messages=input_messages, + sampling_params=updated_sampling_params, + tools=tools, + request_id=request_id, + ) + break + except Exception as e: + logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}") + time.sleep(i + 1) + return output_message[-1] # type: ignore + + async def llm_chat_trinity( + messages: List[Dict[str, str]], + custom_sampling_params: dict = {}, + tools=[], + request_id: str = "", + ) -> dict: + async def main(): + updated_sampling_params = {} + if sampling_params: + updated_sampling_params.update(sampling_params) + if custom_sampling_params: + updated_sampling_params.update(custom_sampling_params) + updated_sampling_params.pop("min_tokens") + + if tools: + response = await self.async_rollout_manager.chat.completions.create( + model=self.async_rollout_manager.model_path, + messages=messages, + logprobs=True, + tools=tools, + top_logprobs=0, + **updated_sampling_params, + ) + else: + response = await self.async_rollout_manager.chat.completions.create( + model=self.async_rollout_manager.model_path, + messages=messages, + logprobs=True, + top_logprobs=0, + **updated_sampling_params, + ) + return response + + response = await main() + prompt_text = self.tokenizer.decode(response.model_extra["prompt_token_ids"]) + prompt_token_ids = response.model_extra["prompt_token_ids"] + content = response.choices[0].message.content + message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) + + if content is None: + content = "" + + if ("" in content) and (not message.get("tool_calls", None)): + # logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}") + logger.warning(f"Bad toolcall discovered: {content}") + + return { + "role": "assistant", + "request_id": response.id, + "content": content, + "prompt_text": prompt_text, + "prompt_token_ids": prompt_token_ids, + "tool_calls": message.get("tool_calls", []), + "tokens": [ + TokenAndProb( + token_id=token, + logprob=tokenlogprob.logprob, # Warning: vllm logprob does not participant training, for log only. + decoded_string=tokenlogprob.token, + ) + for tokenlogprob, token in zip( + response.choices[0].logprobs.content, + response.choices[0].token_ids, + ) + ], + } + + if self.llm_mode == "remote": + return llm_chat_remote + if self.llm_mode == "trinity": + return llm_chat_trinity + else: + return llm_chat_verl + + +# ---------------------------------------------------------------------------------------------- +# ------------------------ call async llm with context tracker (OpenAI) ------------------------ +# ---------------------------------------------------------------------------------------------- + + +class OpenaiLlmProxyWithTracker(object): + """ + An essential wrapper to connect AsyncLlmBridge with AgentScope + + User_user_workflow <-> AsyncLlmBridge <-> Context Tracker. + """ + + def __init__( + self, + llm_inference_fn: Callable, # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse] + context_tracker: MultiAgentContextTracker, + config, + ) -> None: + self.context_tracker = context_tracker + self.llm_inference_fn = llm_inference_fn + self.config = config + + async def __call__( + self, + messages: List[dict], + tools: List = [], + tool_choice: str = "auto", + structured_model=None, + **kwargs, + ) -> ChatResponse: + llm_output = await self.run_infer(messages, tools, tool_choice, structured_model, **kwargs) + return convert_llm_proxy_response_to_oai_response(llm_output) + + async def run_infer( + self, + messages: List[dict], + tools: List = [], + tool_choice: str = "auto", # always auto + structured_model=None, # this is for AgentScope only + **kwargs, + ): + # generate timeline uuid + timeline_uuid = uuid.uuid4().hex + + # prepare context tracker, check context safety + ( + context_safe, + token_overflow, + info, + converted_message, + custom_sampling_params, + tools, + ) = self.context_tracker.step_prepare(messages, tools, timeline_uuid=timeline_uuid) + + # if context not safe to infer further + if not context_safe: + logger.warning(f"[{info}] detected.") + self.context_tracker.context_overflow = True + if token_overflow: + # ajet_action_when_overflow = self.config.ajet.rollout.ajet_action_when_overflow + # cannot proceed due to context overflow + return self.construct_overflow_response() + # else: + # otherwise, for abnormal output, can still proceed, but we do not track output anymore + + # run llm inference ✨ + if self.config.ajet.task_runner.llm_infer_submit_method == "sync": + llm_output = await asyncio.to_thread(self.llm_inference_fn, converted_message, custom_sampling_params, tools) + else: + llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools) + + # begin context tracking + self.context_tracker.step_track(llm_output, context_safe, converted_message, tools, timeline_uuid=timeline_uuid) + return llm_output + + def construct_overflow_response(self): + return { + "role": "assistant", + "request_id": "overflow_response", + "content": "ajet_proxy: Exceeded max model context length.", + "tool_calls": None, + "tokens": [], + } + + +# ---------------------------------------------------------------------------------------------- +# ------------------------ call async llm with context tracker (AgentScope) -------------------- +# ---------------------------------------------------------------------------------------------- + + +class AgentScopeLlmProxyWithTracker(OpenaiLlmProxyWithTracker): + async def __call__( + self, + messages: List[dict], + tools: List = [], + tool_choice: str = "auto", + structured_model=None, + **kwargs, + ) -> AgentScopeChatResponse: + llm_output = await self.run_infer(messages, tools, tool_choice, structured_model) + response = convert_llm_proxy_response_to_agentscope_response(llm_output, structured_model=structured_model) + return response diff --git a/astune/context_manager/cmt_foreign_llm.py b/ajet/task_rollout/dashscope_llm_bridge.py similarity index 68% rename from astune/context_manager/cmt_foreign_llm.py rename to ajet/task_rollout/dashscope_llm_bridge.py index aa4cc24b..321dd572 100644 --- a/astune/context_manager/cmt_foreign_llm.py +++ b/ajet/task_rollout/dashscope_llm_bridge.py @@ -1,24 +1,30 @@ +import os import random import time -import os from textwrap import dedent -from openai import OpenAI + from loguru import logger +from openai import OpenAI -def construct_alien_llm_chat_fn(config, *args): - def alien_llm_chat_fn(messages, request_id=""): + +def create_external_llm_fn(alien_llm_model, alien_llm_response_length): + def external_llm_chat_fn(messages, sampling_params_override={}, request_id=""): max_try = 4 - alien_model_name = config.astune.context_manager.alien_llm_model - alien_model_response_length = config.astune.context_manager.alien_llm_response_length + alien_model_name = alien_llm_model + alien_model_response_length = alien_llm_response_length if os.environ.get("DASHSCOPE_API_KEY") is None or os.environ.get("DASHSCOPE_API_KEY_BACKUP") is None: - raise RuntimeError(dedent(""" + raise RuntimeError( + dedent( + """ Please set the DASHSCOPE_API_KEY and DASHSCOPE_API_KEY_BACKUP environment variables. You can get the API keys from https://www.dashscope.com/. Example: export DASHSCOPE_API_KEY='sk-xxxxxx|sk-yyyyyy' export DASHSCOPE_API_KEY_BACKUP='sk-zzzzzz' - """)) + """ + ) + ) regular_key_list = os.environ.get("DASHSCOPE_API_KEY") backup_key_list = os.environ.get("DASHSCOPE_API_KEY_BACKUP") @@ -31,15 +37,14 @@ def alien_llm_chat_fn(messages, request_id=""): else: backup_key_list = [] - for n_try in range(max_try): try: if n_try < max_try // 2: - api_key=random.choice(regular_key_list) + api_key = random.choice(regular_key_list) elif n_try == max_try // 2: - api_key=random.choice(backup_key_list) + api_key = random.choice(backup_key_list) else: - api_key=random.choice(regular_key_list + backup_key_list) + api_key = random.choice(regular_key_list + backup_key_list) client = OpenAI( api_key=api_key, base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", @@ -47,19 +52,22 @@ def alien_llm_chat_fn(messages, request_id=""): sampling_params = dict( n=1, max_completion_tokens=alien_model_response_length, + temperature=0.0, ) - sampling_params["temperature"] = 0 + sampling_params.update(sampling_params_override) completion = client.chat.completions.create( model=alien_model_name, messages=messages, - extra_body=sampling_params + extra_body=sampling_params, ) message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) - if "content" not in message: message["content"] = "" - return {"role": message["role"], "content": message['content']} + if "content" not in message: + message["content"] = "" + return {"role": message["role"], "content": message["content"]} except Exception as e: logger.bind(exception=True).exception(f"Error calling alien llm: {e}") + logger.warning(f"Error calling alien llm: {e}, retrying...") time.sleep(5) - print(f"Error calling alien llm: {e}, retrying...") raise RuntimeError(f"Failed to get response from alien llm after {max_try} attempts") - return alien_llm_chat_fn + + return external_llm_chat_fn diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py new file mode 100644 index 00000000..b8aae915 --- /dev/null +++ b/ajet/task_rollout/native_parallel_worker.py @@ -0,0 +1,568 @@ +"""Parallel environment rollout orchestration utilities.""" + +import os +import time +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Dict, List, Literal +from urllib.parse import quote + +import numpy as np +import torch +from loguru import logger +from tensordict import TensorDict +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +from verl import DataProto +from verl.utils.torch_functional import pad_sequence_to_length + +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import Task +from ajet.schema.trajectory import Sample +from ajet.task_rollout.single_worker import BaseRolloutManager + + +class DynamicRolloutManager(BaseRolloutManager): + """Dynamic rollout supporting oversampling and early termination.""" + + def step_status_printer(self, observation_window): + """Pretty-print thread progress statistics for the shared obs window.""" + # Histogram buckets: observation_window['step'] 0~5 / 5~10 / 10~15 / ... + step_counter = {} + current_token = sum(observation_window["token"]) + current_time = time.time() + delta_token = current_token - self.current_token + if delta_token < 0: + delta_token = current_token + delta_time = current_time - self.current_token_count_time + self.current_token = current_token + self.current_token_count_time = current_time + token_gen_per_sec_str = f"{delta_token/delta_time:.2f} tokens/s" if delta_time > 0 else "N/A" + + for step in observation_window["step"]: + if step == -1: + step_counter[(-1, "terminated")] = step_counter.get((-1, "terminated"), 0) + 1 + continue + else: + start = (step // 5) * 5 + end = start + 5 + step_counter[(start, end)] = step_counter.get((start, end), 0) + 1 + + step_counter = dict(sorted(step_counter.items(), key=lambda x: x[0][0])) + + print_buf = [] + for (start, end), count in step_counter.items(): + if start != -1: + print_buf += [f"[{start}-{end}]:{count} threads"] + for (start, end), count in step_counter.items(): + if start == -1: + print_buf += [f"[finished]:{count} threads"] + print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf)) + + def rollout_static( + self, + tasks: List[Task], + mode: Literal["sample", "validate"], + epoch: str, + ) -> List[BaseContextTracker]: + """Execute non-dynamic rollouts in parallel and return collected trackers.""" + self.current_token_count_time = time.time() + tracker_array: List[BaseContextTracker] = [] + rollout_n = 1 if mode == "validate" else self.rollout_n + observation_window = { + "step": [0 for _ in range(len(tasks) * rollout_n)], + "token": [0 for _ in range(len(tasks) * rollout_n)], + "stop": [False for _ in range(len(tasks) * rollout_n)], + } + with ThreadPoolExecutor(max_workers=self.max_parallel) as executor: + futures: List[Future] = [] + for task_batch_index, task in enumerate(tasks): + for task_rollout_index in range(rollout_n): + task_thread_index = task_batch_index * rollout_n + task_rollout_index + future = executor.submit( + self.rollout_env_worker, + task=task, + task_batch_index=task_batch_index, + task_tag=f"T{task.task_id}#R{task_rollout_index}", + mode=mode, + task_thread_index=task_thread_index, + observation_window=observation_window, + ) + futures.append(future) + + while True: + if not any(future.running() for future in futures): + break + + completed_futures = [f for f in futures if f.done()] + failed_futures = [f for f in completed_futures if f.exception() is not None] + + if failed_futures: + executor.shutdown(wait=False, cancel_futures=True) + for f in futures: + if not f.done(): + f.cancel() + + for f in failed_futures: + logger.error(f"Thread failed with exception: {f.exception()}") + + raise RuntimeError(f"One of the rollout threads has encountered an exception. {len(failed_futures)} threads failed.") + + self.step_status_printer(observation_window) + time.sleep(10) + + for future in tqdm(futures, desc=f"epoch{epoch}.collect_rollout"): + result = future.result() + tracker_array.append(result) + + # TODO: support multi-step reward + task_success_rate = np.mean([tracker.reward_structure.success_rate for tracker in tracker_array]) + task_scalar_reward = np.mean([tracker.reward_structure.final_scalar_reward for tracker in tracker_array]) + + for tracker in tracker_array: + tracker.current_batch_success_rate = float(task_success_rate) + tracker.current_batch_reward = float(task_scalar_reward) + + return tracker_array + + def rollout( + self, + tasks: List[Task], + mode: Literal["sample", "validate"], + epoch: str, + ) -> List[BaseContextTracker]: + """Delegate to dynamic rollout when oversampling is enabled.""" + if mode == "sample" and (self.rollout_n != 1) and self.config.ajet.rollout.enable_oversample: + return self.rollout_dynamic(tasks, mode, epoch) + else: + return self.rollout_static(tasks, mode, epoch) + + def greedy_max_std_selection(self, samples: List[BaseContextTracker], n): + """Select samples whose rewards maximize spread to cover diverse rollouts.""" + if len(samples) < n: + additional_n = n - len(samples) + n = len(samples) + else: + additional_n = 0 + + sorted_samples = sorted( + samples, + key=lambda tracker: abs(tracker.reward_structure.performance_reward), + ) + value_array = [tracker.reward_structure.performance_reward for tracker in sorted_samples] + macro_selected_value = [] + macro_selected_index = [] + while len(macro_selected_index) != n: + selected_value = [] + selected_index = [] + for index, value in enumerate(value_array): + if (value not in selected_value) and (index not in macro_selected_index): + selected_value.append(value) + selected_index.append(index) + + if len(selected_value) + len(macro_selected_value) <= n: + macro_selected_value += selected_value + macro_selected_index += selected_index + + elif len(selected_value) + len(macro_selected_value) > n: + preserve_n = n - len(macro_selected_value) + pick_left = preserve_n // 2 + pick_right = preserve_n - pick_left + macro_selected_value += selected_value[:pick_left] + selected_value[-pick_right:] + macro_selected_index += selected_index[:pick_left] + selected_index[-pick_right:] + + if additional_n > 0: + additional_indices = np.random.choice(macro_selected_index, additional_n, replace=True) + macro_selected_index += additional_indices.tolist() + + selected_samples = [sorted_samples[i] for i in macro_selected_index] + sorted_selected_samples = sorted( + selected_samples, + key=lambda tracker: abs(tracker.reward_structure.performance_reward), + ) + return sorted_selected_samples + + def rollout_dynamic( # noqa: C901 + self, + tasks: List[Task], + mode: Literal["sample", "validate"], + epoch: str, + allow_sample_num_change=True, + allow_force_stop=True, + ) -> List[BaseContextTracker]: + """Perform oversampled rollouts with optional early termination heuristics.""" + + tracker_array: List[BaseContextTracker] = [] + assert mode != "validate" + rollout_n = self.rollout_n + self.current_token_count_time = time.time() + submit_oversample_multiplier = self.config.ajet.rollout.submit_oversample_multiplier + rollout_n_oversample = int(rollout_n * submit_oversample_multiplier) + rollout_n_confirm = int(rollout_n * (1 + submit_oversample_multiplier) / 2) + assert rollout_n < rollout_n_confirm < rollout_n_oversample, f"submit_oversample_multiplier is too small, rollout_n={rollout_n}, rollout_n_confirm={rollout_n_confirm}, rollout_n_oversample={rollout_n_oversample}" + + observation_window: Dict[str, List[int | bool]] = { + "step": [0 for _ in range(len(tasks) * rollout_n_oversample)], + "stop": [False for _ in range(len(tasks) * rollout_n_oversample)], + "token": [0 for _ in range(len(tasks) * rollout_n_oversample)], + } + + with ThreadPoolExecutor(max_workers=self.max_parallel) as executor: + futures = [] + for task_batch_index, task in enumerate(tasks): + task_future_array = [] + for task_rollout_index in range(rollout_n_oversample): + task_thread_index = task_batch_index * rollout_n_oversample + task_rollout_index + future = executor.submit( + self.rollout_env_worker, + task=task, + task_batch_index=task_batch_index, + task_tag=f"T{task.task_id}#R{task_rollout_index}", + mode=mode, + task_thread_index=task_thread_index, + observation_window=observation_window, + ) + task_future_array.append(future) + futures += [task_future_array] + + # A while loop to wait for all task can be terminated + tic = -1 + while True: + tic += 1 + can_terminate = [False for _ in futures] + terminate_status = ["running" for _ in futures] + for j, task_future_array in enumerate(futures): + completed_task_futures = [f for f in task_future_array if f.done()] + completed_results = [f.result() for f in completed_task_futures] + completed_results = [tracker for tracker in completed_results if not tracker.discarded] + reward = [tracker.reward_structure.performance_reward for tracker in completed_results] + reward_std = np.std(reward) if reward else 0.0 + all_finished = len(completed_task_futures) == len(task_future_array) + if all_finished: + can_terminate[j] = True + terminate_status[j] = f"all_fin({len(completed_results)}/{reward_std:.2f})" + num_finished = len(completed_task_futures) + task_cmd_reward_array = [tracker.reward_structure.performance_reward for tracker in completed_results] + all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) + if not all_equal: + if num_finished >= rollout_n: + can_terminate[j] = True + terminate_status[j] = f"early_end({len(completed_results)}/{reward_std:.2f})" + else: + pass + else: + if num_finished >= rollout_n_confirm: + can_terminate[j] = True + terminate_status[j] = f"confirm_dummy({len(completed_results)}/{reward_std:.2f})" + if allow_force_stop: + for k in range( + j * rollout_n_oversample, + j * rollout_n_oversample + rollout_n_oversample, + ): + observation_window["stop"][k] = True + else: + pass + terminate_status = "/".join(terminate_status) + if all(can_terminate): + logger.info(f"epoch{epoch}.collect_rollout: all tasks finished, exiting loop") + for i, stop_flag in enumerate(observation_window["stop"]): + observation_window["stop"][i] = True + break + else: + if tic % 10 == 0: + self.step_status_printer(observation_window) + logger.info(f"task complete {sum(can_terminate)}/{len(can_terminate)} tasks: {terminate_status}") + time.sleep(5) + + # We have enough number of samples, but we need to wait for all threads to finish, including discarded threads + tic = -1 + while any(f.running() for task_future_array in futures for f in task_future_array): + tic += 1 + if tic % 10 == 0: + logger.info("waiting final sync, this will not take long") + time.sleep(5) + + # find sample group that has identical reward, mark them as need_amend + task_ineffective_thread_cnt = [] + task_completed_thread_cnt = [] # how many effective threads are obtained per group + task_extra_thread_cnt = [] # using rollout_n as baseline, how many extra threads are obtained per group + task_need_amend = 0 # how many groups need amendment due to identical rewards + for j, task_future_array in enumerate(futures): + completed_task_futures = [f for f in task_future_array if f.done()] + completed_results = [f.result() for f in completed_task_futures] + completed_results = [tracker for tracker in completed_results if not tracker.discarded] + task_cmd_reward_array = [tracker.reward_structure.performance_reward for tracker in completed_results] + all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) + completed_task_cnt = len(completed_results) + if all_equal: + task_need_amend += 1 + task_completed_thread_cnt += [0] + task_extra_thread_cnt += [0] + task_ineffective_thread_cnt += [completed_task_cnt] + else: + task_need_amend += 0 + task_completed_thread_cnt += [completed_task_cnt] + task_extra_thread_cnt += [completed_task_cnt - rollout_n] + task_ineffective_thread_cnt += [0] + + logger.info(f"task_completed_thread_cnt: {task_completed_thread_cnt}") + logger.info(f"task_extra_thread_cnt: {task_extra_thread_cnt}") + + # reduce `task_extra_thread_cnt` + world_size = self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes + # the number of all reward-diverse samples + total_sample = sum(task_completed_thread_cnt) + + # begin to compute a removal plan (output: `task_extra_thread_cnt` and `num_task_to_amend`) + # - task_extra_thread_cnt: using rollout_n as baseline, how many extra threads to preserve per group + # - num_task_to_amend: how many groups can be amended according to removal plan + if allow_sample_num_change and (total_sample > world_size * 2): + # When changing the number of samples is ALLOWED + num_task_to_amend = len(futures) # this means infinate budget to amend, indicating that we throw away all ineffective samples + task_extra_thread_cnt = task_extra_thread_cnt # do not change extra thread cnt, we simply take all diverse samples + # log + logger.info(f"task_completed_thread_cnt (after remove): {task_completed_thread_cnt}") + logger.info(f"task_extra_thread_cnt (after remove): {task_extra_thread_cnt}") + else: + # When changing the number of samples is NOT ALLOWED (or the number of samples are too small) + # compute how many valid extra samples are obtained during previous oversampling + num_task_max_to_amend = sum(task_extra_thread_cnt) // rollout_n + # compute how many tasks actually need amendment, we fix as many as we can, but not exceed `num_task_max_to_amend`: + # - num_task_max_to_amend: how many CAN be fixed + # - task_need_amend: how many SHOULD be fixed + num_task_to_amend = min(num_task_max_to_amend, task_need_amend) + # according to `num_task_to_amend`, how many extra samples should be CONSUMED + extra_num_thread_required = num_task_to_amend * rollout_n + # after CONSUME, how many extra samples are really EXTRA and should be REMOVED + remove_count = sum(task_extra_thread_cnt) - extra_num_thread_required + logger.info(f"forbid_sample_num_change policy: num_task_max_to_amend: {num_task_max_to_amend}, " f"num_task_to_amend: {num_task_to_amend}, remove_count: {remove_count}, ") + # remove extra samples according to `remove_count` + while remove_count != 0: + # if we should remove some extra samples, we always remove from the group that has the MOST extra samples + max_extra_index = task_extra_thread_cnt.index(max(task_extra_thread_cnt)) + assert task_extra_thread_cnt[max_extra_index] > 0, "task_extra_thread_cnt should be greater than 0" + task_extra_thread_cnt[max_extra_index] -= 1 + task_completed_thread_cnt[max_extra_index] -= 1 + remove_count -= 1 + + # now, we have computed the final `task_extra_thread_cnt` and `num_task_to_amend`, which the removal plan deps + logger.info(f"task_completed_thread_cnt (after remove): {task_completed_thread_cnt}") + logger.info(f"task_extra_thread_cnt (after remove): {task_extra_thread_cnt}") + + # collect results and get the final tracker_array according to removal plan (`task_extra_thread_cnt` and `num_task_to_amend`) + tracker_array = [] + print_buffer = "" + task_success_rate = [] + task_group_reward = [] + for j, task_future_array, avail_extra_cnt in zip(range(len(futures)), futures, task_extra_thread_cnt): + completed_task_futures = [f for f in task_future_array if f.done()] + completed_results = [f.result() for f in completed_task_futures] + completed_results = [tracker for tracker in completed_results if not tracker.discarded] + # in-group success rate and reward + task_cmd_reward_array = [tracker.reward_structure.performance_reward for tracker in completed_results] + success_rate_array = [tracker.reward_structure.success_rate for tracker in completed_results] + task_group_reward += [np.mean([tracker.reward_structure.final_scalar_reward for tracker in completed_results])] + task_success_rate += [np.mean(success_rate_array)] + # whether this group need amendment + need_amend = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) + # if so, whether we have quota (num_task_to_amend) to amend + if need_amend and (num_task_to_amend > 0): + # this group need amendment, so, we discard all its samples + # do not worry, other groups will take up the slack + num_task_to_amend -= 1 + print_buffer += "/(amend)" + continue + else: + if need_amend: + # this group need amendment, but we simply do not have quota to amend + # we just accept rollout_n samples from this group + num_to_be_selected = rollout_n + else: + # this group is good and healthy, if it has extra samples, we accept them + num_to_be_selected = rollout_n + avail_extra_cnt + # if num_to_be_selected > the number of resulting samples, we choose them to maximum reward diversity + selected_tracker_array = self.greedy_max_std_selection(completed_results, num_to_be_selected) + # good, we have collected selected samples from this group + tracker_array += selected_tracker_array + # print info + print_buffer += f"/({len(selected_tracker_array)})" + if need_amend: + print_buffer += "(no-amend)" + + logger.info(print_buffer) + + # for tracker in tracker_array: + # # average of gourp success rate + # tracker.current_batch_success_rate = np.mean(task_success_rate) + # # average of gourp average reward + # tracker.current_batch_reward = np.mean(task_group_reward) + + return tracker_array + + +class VerlRolloutManager(DynamicRolloutManager): + """High-level manager orchestrating rollouts and batch conversion.""" + + def to_dataproto(self, tracker_array) -> DataProto: + """Convert completed context trackers into a `DataProto` minibatch.""" + samples = self.trajectories_to_samples(tracker_array) + dataproto = self.samples_to_dataproto(samples) + return dataproto + + def trajectories_to_samples(self, tracker_array: List[BaseContextTracker]) -> List[Sample]: + """Tokenize each tracker into `Sample` objects ready for tensorization.""" + sample_arr_final = [] + BaseContextTracker.compute_reference_advantage(tracker_array) + for tracker in tracker_array: + try: + sample_arr = tracker.group_tokenize() + except Exception as e: + raise e + finally: + tracker.generate_log(global_step=self.current_global_steps) + if os.environ.get("BEST_LOGGER_PATH", None) and os.environ.get("AJET_DEBUG", None): + logger.success(f"View rollout details at [http://localhost:8181/?path={quote(os.path.abspath(os.environ['BEST_LOGGER_PATH']))}]") + sample_arr_final += sample_arr + + if self.config.ajet.backbone in ["verl"]: + world_size = self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes + remainder = len(sample_arr_final) % world_size + if remainder != 0: + import random + + remove_indices = random.sample(range(len(sample_arr_final)), remainder) + remove_indices.sort(reverse=True) + for idx in remove_indices: + sample_arr_final.pop(idx) + + return sample_arr_final + + def samples_to_dataproto(self, samples: list[Sample]) -> DataProto: + """Pad sample fields and pack them into the `DataProto` structure expected by VERL.""" + prompt_ids: torch.Tensor | List[torch.Tensor] = [] + response_ids: torch.Tensor | List[torch.Tensor] = [] + prompt_attention_mask: torch.Tensor | List[torch.Tensor] = [] + response_attention_mask: torch.Tensor | List[torch.Tensor] = [] + prompt_position_ids: torch.Tensor | List[torch.Tensor] = [] + response_position_ids: torch.Tensor | List[torch.Tensor] = [] + prompt_loss_mask: torch.Tensor | List[torch.Tensor] = [] + response_loss_mask: torch.Tensor | List[torch.Tensor] = [] + + messages = [] + step_reward_scores = [] + task_ids = [] + rollout_ids = [] + reference_advantage = [] + + for sample in samples: + assert len(sample.input_ids) == len(sample.attention_mask) == len(sample.position_ids) == len(sample.loss_mask), f"Sample has mismatched lengths: {len(sample.input_ids)=}, {len(sample.attention_mask)=}, {len(sample.position_ids)=}, {len(sample.loss_mask)=}" + + task_ids.append(sample.task_id) + rollout_ids.append(sample.task_tag) + if len(sample.prompt_ids) > self.config.ajet.data.max_prompt_length: + raise RuntimeError(f"Sample has prompt_ids length {len(sample.prompt_ids)} ") + + if len(sample.response_ids) > self.config.ajet.data.max_response_length: + raise RuntimeError(f"Sample has prompt_ids length {len(sample.prompt_ids)} ") + + assert len(sample.prompt_ids) != 0 + assert len(sample.response_ids) != 0 + prompt_ids.append(torch.tensor(sample.prompt_ids, dtype=torch.int)) + response_ids.append(torch.tensor(sample.response_ids, dtype=torch.int)) + + prompt_attention_mask.append(torch.tensor(sample.prompt_attention_mask, dtype=torch.int)) + response_attention_mask.append(torch.tensor(sample.response_attention_mask, dtype=torch.int)) + + prompt_position_ids.append(torch.tensor(sample.prompt_position_ids, dtype=torch.int)) + response_position_ids.append(torch.tensor(sample.response_position_ids, dtype=torch.int)) + + prompt_loss_mask.append(torch.tensor(sample.prompt_loss_mask, dtype=torch.int)) + response_loss_mask.append(torch.tensor(sample.response_loss_mask, dtype=torch.int)) + + reference_advantage.append(sample.reference_advantage) + + messages.append({"messages": sample.messages}) + step_reward_scores.append(sample.step_reward) # append reward scalar + + max_prompt_length_this_batch = max([p.shape[-1] for p in prompt_ids]) + assert max_prompt_length_this_batch <= self.config.ajet.data.max_prompt_length + max_response_length_this_batch = max([p.shape[-1] for p in response_ids]) + assert max_response_length_this_batch <= self.config.ajet.data.max_response_length + + prompt_ids = pad_sequence( + prompt_ids, + batch_first=True, + padding_value=self.pad_token_id, + padding_side="left", + ) + prompt_attention_mask = pad_sequence( + prompt_attention_mask, + batch_first=True, + padding_value=0, + padding_side="left", + ) + prompt_position_ids = pad_sequence( + prompt_position_ids, + batch_first=True, + padding_value=0, + padding_side="left", + ) + prompt_loss_mask = pad_sequence( + prompt_loss_mask, + batch_first=True, + padding_value=0, + padding_side="left", + ) + + prompt_ids = pad_sequence_to_length( + prompt_ids, + max_prompt_length_this_batch, + self.pad_token_id, + left_pad=True, + ) + prompt_attention_mask = pad_sequence_to_length( + prompt_attention_mask, + max_prompt_length_this_batch, + 0, + left_pad=True, + ) + prompt_position_ids = pad_sequence_to_length(prompt_position_ids, max_prompt_length_this_batch, 0, left_pad=True) + prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, max_prompt_length_this_batch, 0, left_pad=True) + + response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) + response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) + response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0) + + response_ids = pad_sequence_to_length(response_ids, max_response_length_this_batch, self.pad_token_id) + response_attention_mask = pad_sequence_to_length(response_attention_mask, max_response_length_this_batch, 0) + response_loss_mask = pad_sequence_to_length(response_loss_mask, max_response_length_this_batch, 0) + + delta_position_id = torch.arange(1, response_ids.size(1) + 1, device=response_ids.device).unsqueeze(0).repeat(len(samples), 1) + response_position_ids = prompt_position_ids[:, -1:] + delta_position_id + + input_ids = torch.cat((prompt_ids, response_ids), dim=-1) + attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) + position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) + loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1) + + batch = TensorDict( + { + "prompts": prompt_ids, + "responses": response_ids, + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + }, + batch_size=len(samples), + ) + + return DataProto( + batch=batch, + non_tensor_batch={ + "task_ids": np.array(task_ids), + "rollout_ids": np.array(rollout_ids), + "messages": np.array(messages), + "reward_scores": np.array(step_reward_scores), + "reference_advantage": np.array(reference_advantage), + }, + ) diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py new file mode 100644 index 00000000..47a7412f --- /dev/null +++ b/ajet/task_rollout/resource_keeper.py @@ -0,0 +1,197 @@ +from typing import Any, Dict, List, Tuple + +from loguru import logger +from omegaconf import DictConfig + +from ajet.schema.task import WorkflowTask +from ajet.utils.env_service_client.env_client_ng import ( + EnvClient as EnvClientNg, +) + + +class ResourceKeeper(object): + """ + TODO: integrate with A.S. Runtime + """ + + def __init__(self, workflow_task: WorkflowTask, config: DictConfig): + self.workflow_task = workflow_task + self.config = config + + def __enter__(self): + self.config = self.config + self.workflow_task = self.workflow_task + self.task_id: str = self.workflow_task.task_id + self.tokenizer = self.workflow_task.tokenizer + self.llm_inference_fn = self.workflow_task.llm_inference_fn + self.observation_window = self.workflow_task.observation_window + if self.config.ajet.task_reader.type == "env_service": + url = self.config.ajet.task_reader.env_service.env_url + env_type = self.config.ajet.task_reader.env_service.env_type + self.env = EnvClientNg(base_url=url) + self.env_params = {} + self.env_type: str = env_type + else: + self.env = None + return self + + def __exit__(self, exc_type, exc_value, traceback): + try: + if self.env: + self.env.release_instance(self.workflow_task.episode_uuid) + except Exception as e: + logger.bind(exception=True).exception(f"encounter exception in env_worker.release_instance~ error={e.args}") + raise e + + def prepare(self): + """ + Prepare the environment and initial messages for the workflow task. + + Returns: + WorkflowTask: The updated workflow task with initialized environment and messages. + """ + init_messages = self._initialize_environment_and_messages() + self.workflow_task.task.init_messages = init_messages + self.workflow_task.gym_env = self.generate_gym_env( + self.env, + self.workflow_task.episode_uuid, + self.workflow_task.task_thread_index, + self.workflow_task.observation_window, + ) + + return self.workflow_task + + def _initialize_environment_and_messages(self) -> List[dict]: + """ + Initialize environment instance and setup initial messages. + + Returns: + List[dict]: Initial messages for the agent flow + + Raises: + Exception: If environment creation fails or required task data is missing + """ + + if self.config.ajet.task_reader.type == "env_service": + if self.env is None: + raise ValueError("Environment client is None but env_service type is specified") + try: + init_response = self.env.create_instance( + env_type=self.env_type, + task_id=self.task_id, + instance_id=self.workflow_task.episode_uuid, + params=self.env_params, + ) + state_message: dict = init_response["state"] + query, init_messages = self._get_init_messages(state_message) + # Update main_query with actual query from environment + self.workflow_task.task.main_query = query + except Exception as e: + logger.bind(exception=True).exception(f"encounter exception in env_worker.create_instance~ error={e.args}") + if self.env is not None: + self.env.release_instance(self.workflow_task.episode_uuid) + raise e + else: + task = self.workflow_task.task + if task.init_messages: + init_messages = task.init_messages + else: + assert task.main_query, "You must provide init_messages or main_query in task." + init_messages = [{"role": "user", "content": task.main_query}] + + return init_messages + + def _get_init_messages(self, state_message) -> tuple: + """ + Process state_message to extract query and init_messages. + + Args: + state_message (Union[dict, list]): The state message to process + + Returns: + tuple: (query, init_messages) where query is a string and init_messages is a list + + Raises: + ValueError: If state_message is neither dict nor list + """ + if isinstance(state_message, dict): + query = state_message["content"] + init_messages = [state_message] + elif isinstance(state_message, list): + assert isinstance(state_message[0], dict) + query = state_message[-1]["content"] + init_messages = state_message + else: + raise ValueError(f"state_message should be dict or list, but got {type(state_message)}") + + return query, init_messages + + def generate_gym_env(self, env_client: Any, episode_uuid: str, task_thread_index: int, observation_window: Dict) -> "BaseGymEnv": + return BaseGymEnv(env_client, episode_uuid, task_thread_index, observation_window) + + +class BaseGymEnv(object): + """ + TODO: integrate with A.S. Runtime + """ + + def __init__( + self, + env_client: EnvClientNg, + episode_uuid: str, + task_thread_index: int, + observation_window: Dict, + ): + self.env_client = env_client + self.task_thread_index = task_thread_index + self.observation_window = observation_window + self.episode_uuid = episode_uuid + if self.env_client: + self.service_url = self.env_client.base_url + + def step(self, action: dict) -> Tuple[str, float, bool, dict]: + """Take a step in the gym environment.""" + if not isinstance(action["content"], str): + # assert isinstance(action['content'], list) + # assert len(action['content']) == 1 + # assert isinstance(action['content'][0], dict) + # assert 'type' in action['content'][0] + # assert 'text' in action['content'][0] + try: + action["content"] = action["content"][0]["text"] + except Exception: + logger.exception(f"Failed to parse action content from agentscope output. {action['content']}") + action["content"] = str(action["content"]) + + self.observation_window["step"][self.task_thread_index] += 1 + env_output = self.env_client.step( + instance_id=self.episode_uuid, + action=action, + ) + obs = "" + assert isinstance(env_output, dict) + + if isinstance(env_output["state"], list): + # 1. If state is a list (new standard format), pass through directly + obs = env_output["state"] + else: + # 2. If state is a dict (old format or error) + if ("content" not in env_output["state"]) and ("error" in env_output["state"]): + obs = f"[Error from environment: {env_output['error']}]" + elif env_output["state"].get("content", "") == "": + obs = "Warning: the environment does not provide any feedback, please provide valid inpu and try again." + else: + obs = env_output["state"]["content"] + + reward = 0 + info = {} + terminate = env_output["is_terminated"] + return obs, reward, terminate, info # type: ignore + + def reset(self) -> str: + """Reset gym environment.""" + raise RuntimeError("Reset is not supported") + + def evaluate(self, episode_uuid, params): + """Evaluate and get reward.""" + return self.env_client.evaluate(episode_uuid, params) diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py new file mode 100644 index 00000000..e8865ea9 --- /dev/null +++ b/ajet/task_rollout/single_worker.py @@ -0,0 +1,123 @@ +"""Single worker primitives for environment rollouts.""" + +import uuid +from typing import Literal + +from loguru import logger +from omegaconf import DictConfig +from transformers.tokenization_utils import PreTrainedTokenizer + +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import Task, WorkflowTask +from ajet.task_rollout.async_llm_bridge import AsyncLlmBridge +from ajet.task_rollout.resource_keeper import ResourceKeeper +from ajet.task_runner.general_runner import GeneralRunner +from ajet.utils.retry import retry_with_backoff +from ajet.utils.sample import get_sample_params +from ajet.utils.testing_utils import TestFailException, TestSuccessException + + +class BaseRolloutManager: + def __init__( + self, + config: DictConfig, + async_rollout_manager, + max_parallel: int, + max_llm_retries: int = 3, + tokenizer: PreTrainedTokenizer = None, # type: ignore + llm_mode: Literal["local", "remote", "trinity"] = "local", + **kwargs, + ): + """Initialize common rollout state and helpers. + + Parameters + ---------- + config : DictConfig + Configuration object containing rollout and experiment settings. + async_rollout_manager : Any + Manager responsible for async LLM interactions. + max_parallel : int + Maximum number of parallel environment worker threads. + max_llm_retries : int, optional + Maximum retries for LLM calls, by default 3. + tokenizer : PreTrainedTokenizer, optional + Tokenizer used for padding and ID conversions. + llm_mode : Literal["local", "remote", "trinity"], optional + Indicates backend mode (e.g., 'local', 'remote'), default 'local'. + **kwargs : Any + Additional parameters passed through for future extensions. + """ + + self.llm_mode: Literal["local", "remote", "trinity"] = llm_mode + self.config: DictConfig = config + self.async_rollout_manager = async_rollout_manager + self.max_parallel: int = max_parallel + self.max_llm_retries: int = max_llm_retries + self.rollout_n = config.ajet.rollout.num_repeat + self.tokenizer = tokenizer + self.pad_token_id: int = self.tokenizer.pad_token_id # type: ignore + assert isinstance(self.pad_token_id, int), "pad_token_id must be an integer" + self.current_token = 0 + self.current_global_steps: int | str = "NA" + self.async_llm_bridge = AsyncLlmBridge( + config=config, + async_rollout_manager=async_rollout_manager, + tokenizer=tokenizer, + llm_mode=llm_mode, + max_llm_retries=max_llm_retries, + ) + + @retry_with_backoff(max_retry_attr="max_llm_retries") + def rollout_env_worker( + self, + task: Task, + task_batch_index: int, + task_tag: str, + mode: Literal["sample", "validate"], + task_thread_index: int, + observation_window: dict, + **kwargs, + ) -> BaseContextTracker: + """Execute one environment rollout worker. + + Handles environment initialization, LLM sampling parameter construction + (with validation overrides), and robust retry on transient failures. + """ + sampling_params = get_sample_params(mode, self.config) + + if self.config.ajet.task_runner.llm_infer_submit_method == "sync": + llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_sync(sampling_params=sampling_params) + else: + llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async(sampling_params=sampling_params) + + workflow_task = WorkflowTask( + env_type=task.env_type, + task_id=task.task_id, + task_thread_index=task_thread_index, + task_batch_index=task_batch_index, + episode_uuid=uuid.uuid4().hex, + task_tag=task_tag, + observation_window=observation_window, + llm_inference_fn=llm_inference_fn, + tokenizer=self.tokenizer, + task=task, + ) + + with ResourceKeeper(workflow_task, config=self.config) as resource_keeper: + try: + workflow_task = resource_keeper.prepare() + agent_runner = GeneralRunner(llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config) + tracker = agent_runner.execute( + workflow_task=workflow_task, + ) + except TestSuccessException as e: + logger.success(f"env_worker.agent_flow completed with TestSuccessException: {e.args}") + raise e + except TestFailException as e: + logger.error(f"env_worker.agent_flow failed with TestFailException: {e.args}") + raise e + except Exception as e: + logger.bind(exception=True).exception(f"encounter exception in env_worker.agent_flow error={e.args}") + raise e + + return tracker diff --git a/ajet/task_runner/__init__.py b/ajet/task_runner/__init__.py new file mode 100644 index 00000000..861652de --- /dev/null +++ b/ajet/task_runner/__init__.py @@ -0,0 +1,4 @@ +from .general_runner import GeneralRunner +from .base_runner import BaseAgentRunner + +__all__ = ["BaseAgentRunner", "GeneralRunner"] diff --git a/ajet/task_runner/base_runner.py b/ajet/task_runner/base_runner.py new file mode 100644 index 00000000..3028a829 --- /dev/null +++ b/ajet/task_runner/base_runner.py @@ -0,0 +1,111 @@ +import asyncio +import gc +from threading import Lock +from typing import Any, Callable, Union, Type +from multiprocessing import Process, Queue +from unittest import result + +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import WorkflowOutput, WorkflowTask +from ajet.task_judge.base_judge import BaseJudge +from ajet.tuner import AjetTuner +from ajet.utils.async_utils import run_async_coroutine_with_timeout +from ajet.utils.dynamic_import import dynamic_import +from ajet.workflow import Workflow + +gc_lock = Lock() + + +class BaseAgentRunner(object): + def __init__(self, llm_inference_fn: Callable, tokenizer: Any, config, **kwargs): + self.tokenizer = tokenizer + self.instruction_template_ids = self.tokenizer.encode("<|im_start|>user\n") + self.response_template_ids = self.tokenizer.encode("<|im_start|>assistant\n") + self.tracker: Union[BaseContextTracker, Any, None] = None + self.external_llm_fn: Union[Callable, None] = None + self.llm_inference_fn: Callable = llm_inference_fn + self.config = config + self.max_steps: int = self.config.ajet.rollout.multi_turn.max_steps + self.max_model_len: int = self.config.ajet.rollout.max_model_len + + self.wrapper_type = self.config.ajet.task_runner.wrapper_type + self.wrapper_multiprocessing_timeout = self.config.ajet.task_runner.wrapper_multiprocessing_timeout + assert self.wrapper_type in ["asyncio", "asyncio-with-gc", "multi-processing"], f"Unsupported wrapper type: {self.wrapper_type}, available options: ['asyncio', 'asyncio-with-gc', 'multi-processing']" + + def get_judge(self) -> BaseJudge: # type: ignore + if self.config.ajet.task_judge.judge_type == "customized_protocol": + judge_protocol = self.config.ajet.task_judge.judge_protocol + return dynamic_import(judge_protocol)(self.config) # type: ignore + + elif self.config.ajet.task_judge.judge_type == "rubrics_auto_grader": + # ajet/task_judge/rm_auto_grader_judge.py + from ajet.task_judge.rm_auto_grader_judge import AutoGraderJudge + + judge = AutoGraderJudge(self.config) + run_async_coroutine_with_timeout(judge.load_rubrics_from_cache()) + return judge + + def runner_hooks(self, observation_window, task_thread_index, workflow_task): + def should_interrupt_fn() -> bool: + if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) + return True + return False + + def generated_token_callback_fn(token_array): + observation_window["token"][task_thread_index] += len(token_array) + + return { + "should_interrupt_fn": should_interrupt_fn, + "generated_token_callback_fn": generated_token_callback_fn, + } + + async def wrapper_type_asyncio(self, workflow_cls: Type[Workflow], workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + user_workflow: Workflow = workflow_cls(name="ajet-workflow") + result = await user_workflow.execute(workflow_task, tuner) + + # malloc garbage collection + del user_workflow + + # run gc in a thread-safe way + if gc_lock.acquire(blocking=False): + try: + gc.collect() + finally: + gc_lock.release() + return result + + def wrapper_type_multiprocessing(self, workflow_cls: Type[Workflow], workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + def worker(q: Queue): + user_workflow: Workflow = workflow_cls(name="ajet-workflow") + result = asyncio.run(user_workflow.execute(workflow_task, tuner)) + q.put(result) + + q = Queue() + p = Process(target=worker, args=(q,)) + p.daemon = True + p.start() + p.join(timeout=self.wrapper_multiprocessing_timeout) + if p.is_alive(): + p.terminate() + p.join() + raise TimeoutError(f"Workflow execution timeout after {self.wrapper_multiprocessing_timeout} seconds") + return q.get() + + def run_user_workflow( + self, + workflow_cls: Type[Workflow], + workflow_task: WorkflowTask, + tuner: AjetTuner, + ) -> WorkflowOutput: + if self.wrapper_type == "asyncio": + user_workflow: Workflow = workflow_cls(name="ajet-workflow") + return asyncio.run(user_workflow.execute(workflow_task, tuner)) + + if self.wrapper_type == "asyncio-with-gc": + return asyncio.run(self.wrapper_type_asyncio(workflow_cls, workflow_task, tuner)) + + elif self.wrapper_type == "multi-processing": + return self.wrapper_type_multiprocessing(workflow_cls, workflow_task, tuner) + + else: + raise ValueError(f"Unsupported wrapper type: {self.wrapper_type}") diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py new file mode 100644 index 00000000..7d527340 --- /dev/null +++ b/ajet/task_runner/general_runner.py @@ -0,0 +1,80 @@ +from venv import logger + +from ajet import AjetTuner +from ajet import Workflow, WorkflowOutput +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.context_tracker.basic_tracker import BaseContextTracker +from ajet.schema.task import WorkflowTask +from ajet.schema.trajectory import Reward +from ajet.task_runner.base_runner import BaseAgentRunner +from ajet.utils.dynamic_import import dynamic_import + + +class GeneralRunner(BaseAgentRunner): + def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: + observation_window = workflow_task.observation_window + task_thread_index = workflow_task.task_thread_index + + workflow_import = self.config.ajet.rollout.user_workflow + workflow_cls = dynamic_import(workflow_import) + + hooks = self.runner_hooks( + observation_window=observation_window, + task_thread_index=task_thread_index, + workflow_task=workflow_task, + ) + context_tracker = MultiAgentContextTracker( + llm_inference_fn=self.llm_inference_fn, + tokenizer=self.tokenizer, + config=self.config, + workflow_task=workflow_task, + **hooks, + ) + tuner = AjetTuner( + context_tracker=context_tracker, + llm_inference_fn=self.llm_inference_fn, + workflow_cls=workflow_cls, + config=self.config, + ) + + # run workflow + # user_workflow: Workflow = workflow_cls(name="ajet-workflow") + workflow_output: WorkflowOutput = self.run_user_workflow( + workflow_cls, + workflow_task, + tuner, + ) + + if workflow_output.reward is not None: + raw_reward, is_success = ( + workflow_output.reward, + workflow_output.is_success, + ) + else: + raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) + + workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue + + assert not isinstance(raw_reward, list), "AgentJet will support step reward in future versions." + + # register reward + # TODO: support multi-step reward + reward = Reward( + raw_reward=raw_reward, + raw_step_reward=None, # "AgentJet will support step reward in future versions." + success_rate=1.0 if is_success else 0.0, + madness=0, + description="", + ) + context_tracker.process_reward(reward) + # generate token before merging + context_tracker.group_merge() + # after merging, process and align reward again + context_tracker.process_reward(reward) + # mark the thread as ended + observation_window["step"][task_thread_index] = -1 + tuner.terminate_episode() + context_tracker.log_metrics = workflow_output.log_metrics + return context_tracker diff --git a/ajet/tuner.py b/ajet/tuner.py new file mode 100644 index 00000000..aefff119 --- /dev/null +++ b/ajet/tuner.py @@ -0,0 +1,178 @@ +from typing import TYPE_CHECKING, Any, Literal, Callable, Union, Type + +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) + +from ajet.tuner_lib.weight_tuner import AgentScopeModelTuner +from ajet.tuner_lib.weight_tuner import OpenaiClientModelTuner +from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiClientBaseUrlTuner + +if TYPE_CHECKING: + from ajet import Workflow + +TunerTypeUnion = Union[AgentScopeModelTuner, OpenaiClientModelTuner] + + +class AjetTuner(object): + def __init__( + self, + config, + context_tracker: MultiAgentContextTracker, + workflow_cls: Type["Workflow"], + llm_inference_fn: Callable, + ) -> None: + self.config = config + self.trainable_targets = self._get_trainable_targets(workflow_cls) + self.context_tracker = context_tracker + self.llm_inference_fn = llm_inference_fn + self.target2proxy_registry: dict[str, dict[str, TunerTypeUnion]] = {} + self.enable_interchange_server = config.ajet.enable_experimental_interchange_server + if self.enable_interchange_server: + self.proxy_client_started = False + + def _get_trainable_targets(self, workflow_cls: Type["Workflow"]): + workflow_instance = workflow_cls(name="ajet-workflow") + return workflow_instance.trainable_targets + + def as_agentscope_model(self, agent_name="default_agent_name", target_tag="default_target_tag", debug_model=None) -> "AgentScopeModelTuner": + """Convert to ModelTuner instance for Agentscope workflow. + Returns: + ModelTuner: + The ModelTuner instance for Agentscope workflow. + """ + explicit_tuner_as_modelscope_model = AgentScopeModelTuner( + config=self.config, + context_tracker=self.context_tracker, + agent_name=agent_name, + debug_model=debug_model, + use_debug_model=(not self._is_target_trainable(target_tag)), + llm_inference_fn=self.llm_inference_fn, + ) + self._register(target_tag, agent_name, explicit_tuner_as_modelscope_model) + return explicit_tuner_as_modelscope_model + + def as_raw_openai_sdk_client( + self, + agent_name="default_agent_name", + target_tag="default_target_tag", + debug_model="gpt-4o", + ) -> OpenaiClientModelTuner: + """Convert to raw OpenAI SDK client for advanced usage. + Returns: + Any: + The raw OpenAI SDK client. + """ + explicit_tuner_as_oai_client = OpenaiClientModelTuner( + config=self.config, + context_tracker=self.context_tracker, + agent_name=agent_name, + debug_model=debug_model, + use_debug_model=(not self._is_target_trainable(target_tag)), + llm_inference_fn=self.llm_inference_fn, + ) + self._register(target_tag, agent_name, explicit_tuner_as_oai_client) + return explicit_tuner_as_oai_client + + def as_oai_baseurl_apikey( + self, + agent_name="default_agent_name", + target_tag="default_target_tag", + ): + """ + Usage: + ```python + result = tuner.as_oai_baseurl_apikey() + + # take base_url, api_key, model_name + base_url = result.base_url + api_key = result.api_key + + # use base_url, api_key, model_name + client = AsyncOpenAI(base_url=base_url, api_key=api_key) + response = await client.chat.completions.create( + model='whatever_model_name_you_like', + messages=messages, + ) + ``` + """ + + assert self.enable_interchange_server, "Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature." + if self.proxy_client_started is False: + self.proxy_client_started = True + self._enable_experimental_interchange_server(self.llm_inference_fn) + baseurl_apikey_model = OpenaiClientBaseUrlTuner( + config=self.config, + context_tracker=self.context_tracker, + agent_name=agent_name, + target_tag=target_tag, + episode_uuid=self.context_tracker.episode_uuid, + episode_contect_address=self.interchange_client.episode_contect_address, + ) + return baseurl_apikey_model + + def __call__(self, **kwargs): + """This method is **deprecated**. + The current behavior of this method is pretend as a agentscope model + """ + raise RuntimeError("This method is deprecated. Please use `as_agentscope_model` / `as_raw_openai_sdk_client` first.") + + # ------------------------------------------------------------------------ + # other helper methods + # ------------------------------------------------------------------------ + + def _register(self, target_name: str, agent_name: str, explicit_tuner: TunerTypeUnion) -> TunerTypeUnion: + """Register an agent type. + Args: + target_name (`str`): + The name to register the agent type under. + default_model (`ChatModelBase`): + The model to use when you are NOT training this agent type. + Returns: + Agent2Proxy: + The agent type instance corresponding to the provided name. + """ + if target_name not in self.target2proxy_registry: + self.target2proxy_registry[target_name] = {} + self.target2proxy_registry[target_name][agent_name] = explicit_tuner + return explicit_tuner + + def _is_target_trainable(self, target_name) -> bool: + """Determine whether user have used `trainable_targets` to explicitly control training targets.""" + if self.trainable_targets is None: + # always assume trainable when user has never changed trainable_targets + return True + if not self.trainable_targets: + # always assume trainable when trainable_targets is [] + return True + if target_name in self.trainable_targets: + return True + else: + return False + + def get_context_tracker(self) -> MultiAgentContextTracker: + """Get the context tracker instance. + Returns: + LlmProxyForAgentScope: + The context tracker instance used by the ModelTuner. + """ + return self.context_tracker + + def _enable_experimental_interchange_server(self, llm_inference_fn): + # experimental reverse proxy start + if self.enable_interchange_server: + from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_client import InterchangeClient + + self.interchange_client = InterchangeClient( + episode_uuid=self.context_tracker.episode_uuid, + context_tracker=self.context_tracker, + config=self.config, + llm_inference_fn=llm_inference_fn, + ) + return self.interchange_client.begin_service() + + def terminate_episode(self): + # experimental reverse proxy cleanup + if self.enable_interchange_server: + if (self.proxy_client_started is True) and hasattr(self, "interchange_client"): + self.interchange_client._should_terminate = True diff --git a/ajet/tuner_lib/weight_tuner/__init__.py b/ajet/tuner_lib/weight_tuner/__init__.py new file mode 100644 index 00000000..c50d6787 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/__init__.py @@ -0,0 +1,2 @@ +from ajet.tuner_lib.weight_tuner.as_agentscope_model import AgentScopeModelTuner +from ajet.tuner_lib.weight_tuner.as_oai_sdk_model import OpenaiClientModelTuner diff --git a/ajet/tuner_lib/weight_tuner/as_agentscope_model.py b/ajet/tuner_lib/weight_tuner/as_agentscope_model.py new file mode 100644 index 00000000..d125101a --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/as_agentscope_model.py @@ -0,0 +1,108 @@ +from typing import TYPE_CHECKING, Any, Literal, Type + +from agentscope._utils._common import _create_tool_from_base_model +from agentscope.model import ChatModelBase, ChatResponse, DashScopeChatModel +from loguru import logger +from pydantic import BaseModel + +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.task_rollout.async_llm_bridge import AgentScopeLlmProxyWithTracker + +if TYPE_CHECKING: + from ajet import Workflow + + +class AgentScopeModelTuner(DashScopeChatModel): + """ + ModelTuner for Agentscope workflow. + It keeps record of all registered agent types (by their target names), + And when request comes, it calls `self.llm_proxy` to handle the request. + """ + + def __init__( + self, + config, + context_tracker: MultiAgentContextTracker, + agent_name: str, + debug_model: DashScopeChatModel | None, + use_debug_model: bool = False, + llm_inference_fn=None, + ) -> None: + self.config = config + self.context_tracker = context_tracker + + self.agent_name = agent_name + self.debug_model = debug_model + self.use_debug_model = use_debug_model + self.llm_proxy = AgentScopeLlmProxyWithTracker(context_tracker=context_tracker, config=config, llm_inference_fn=llm_inference_fn) + super().__init__( + model_name="ajet", + api_key="dummy-api-key", + stream=False, + ) + + async def __call__( + self, + messages: list[dict[str, Any]], + tools: list[dict] | None = None, + tool_choice: Literal["auto", "none", "any", "required"] | str | None = None, + structured_model: Type[BaseModel] | None = None, + **kwargs: Any, + ) -> ChatResponse: + # route first + if self.use_debug_model and self.debug_model is not None: + chatresponse = await self.debug_model(messages, tools, tool_choice, structured_model, **kwargs) + assert isinstance(chatresponse, ChatResponse) + return chatresponse + + # For qvq and qwen-vl models, the content field cannot be `None` or + # `[{"text": None}]`, so we need to convert it to an empty list. + if self.model_name.startswith("qvq") or "-vl" in self.model_name: + raise NotImplementedError("Not implemented for qvq and qwen-vl models yet.") + + kwargs = { + "messages": messages, + "model": self.model_name, + "stream": self.stream, + **self.generate_kwargs, + **kwargs, + "result_format": "message", + # In agentscope, the `incremental_output` must be `True` when + # `self.stream` is True + "incremental_output": self.stream, + } + + if tools: + kwargs["tools"] = self._format_tools_json_schemas(tools) + + if tool_choice: + self._validate_tool_choice(tool_choice, tools) + kwargs["tool_choice"] = self._format_tool_choice(tool_choice) + + if self.enable_thinking is not None and "enable_thinking" not in kwargs: + kwargs["enable_thinking"] = self.enable_thinking + + if structured_model: + if tools or tool_choice: + logger.warning( + "structured_model is provided. Both 'tools' and " "'tool_choice' parameters will be overridden and " "ignored. The model will only perform structured output " "generation without calling any other tools.", + ) + format_tool = _create_tool_from_base_model(structured_model) + kwargs["tools"] = self._format_tools_json_schemas( + [format_tool], + ) + kwargs["tool_choice"] = self._format_tool_choice( + format_tool["function"]["name"], + ) + + # call llm model ✨ + response_gen = await self.llm_proxy( + api_key=self.api_key, + structured_model=structured_model, + **kwargs, + ) + + # Return the AsyncGenerator directly + return response_gen diff --git a/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py new file mode 100644 index 00000000..7464cfe4 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py @@ -0,0 +1,71 @@ +import os +import asyncio +from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union +from loguru import logger +from pydantic import BaseModel, Field +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker +from ajet.utils.magic_mock import SpecialMagicMock +from openai.types.chat.chat_completion import ChatCompletion +from openai.resources.chat.chat import Chat, AsyncChat +from openai.resources.completions import AsyncCompletions +from openai import OpenAI, AsyncOpenAI +from ajet.utils.networking import find_free_port +from .experimental.as_oai_model_client import generate_auth_token + +if TYPE_CHECKING: + from ajet import Workflow + + +class MockAsyncCompletions(AsyncCompletions): + async def create(self, *args, **kwargs) -> Any: # type: ignore + return await self._client.create(*args, **kwargs) # type: ignore + + +class MockAsyncChat(AsyncChat): + @property + def completions(self) -> MockAsyncCompletions: # type: ignore + return MockAsyncCompletions(self._client) + + +class OpenaiClientBaseUrlTuner(BaseModel): + """At this layer, we will determine which model to use: + - training model + - debug model assigned by user, used when this target is not being trained + """ + + base_url: str = Field(default="http://localhost:27788/v1", description="The base URL for the Ajet's fake OpenAI API") + api_key: str = Field(default="invalid_apikey", description="The Ajet's fake key, which is not a real key, it is a encoded string contain episode_uuid and other stuff.") + model: str = Field(default="reserved_field", description="reserved field.") + + def __init__( + self, + config, + context_tracker: MultiAgentContextTracker, + target_tag: str, + agent_name: str, + episode_uuid: str, + episode_contect_address: str, + **kwargs, + ): + port = os.getenv("AJET_DAT_INTERCHANGE_PORT") + assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + + base_url = f"http://{master_node_ip}:{port}/v1" + api_key = generate_auth_token( + agent_name=agent_name, + target_tag=target_tag, + episode_uuid=episode_uuid, + episode_address=episode_contect_address, + ) + model = "reserved_field" + + # Properly initialize the Pydantic BaseModel + super().__init__( + base_url=base_url, + api_key=api_key, + model=model, + ) diff --git a/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py b/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py new file mode 100644 index 00000000..8c00e651 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/as_oai_sdk_model.py @@ -0,0 +1,77 @@ +import asyncio +from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union +from loguru import logger +from pydantic import BaseModel +from ajet.context_tracker.multiagent_tracking import ( + MultiAgentContextTracker, +) +from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker +from ajet.utils.magic_mock import SpecialMagicMock +from openai.types.chat.chat_completion import ChatCompletion +from openai.resources.chat.chat import Chat, AsyncChat +from openai.resources.completions import AsyncCompletions +from openai import OpenAI, AsyncOpenAI + +if TYPE_CHECKING: + from ajet import Workflow + + +class MockAsyncCompletions(AsyncCompletions): + async def create(self, *args, **kwargs) -> Any: # type: ignore + return await self._client.create(*args, **kwargs) # type: ignore + + +class MockAsyncChat(AsyncChat): + @property + def completions(self) -> MockAsyncCompletions: # type: ignore + return MockAsyncCompletions(self._client) + + +class OpenaiClientModelTuner(AsyncOpenAI): + """At this layer, we will determine which model to use: + - training model + - debug model assigned by user, used when this target is not being trained + """ + + def __init__( + self, + config, + context_tracker: MultiAgentContextTracker, + agent_name: str, + debug_model: str | None = None, + use_debug_model: bool = False, + llm_inference_fn: Callable | None = None, + ): + self.debug_model = debug_model + self.agent_name = agent_name + self.use_debug_model = use_debug_model + assert llm_inference_fn is not None, "llm_inference_fn must be provided" + self.llm_proxy = OpenaiLlmProxyWithTracker( + context_tracker=context_tracker, + config=config, + llm_inference_fn=llm_inference_fn, + ) + + @property + def chat(self) -> MockAsyncChat: # type: ignore + return MockAsyncChat(self) + + async def create(self, messages: List[dict], tools: List = [], tool_choice: str = "auto", *args, **kwargs) -> ChatCompletion: + # route first + if self.use_debug_model and (self.debug_model is not None): + client = AsyncOpenAI() + return await client.chat.completions.create( + model=self.debug_model, + messages=messages, # type: ignore + tools=tools, + tool_choice=tool_choice, # type: ignore + ) + + # call llm model ✨ + response_gen = await self.llm_proxy( + messages=messages, + tools=tools, + tool_choice=tool_choice, + ) + assert isinstance(response_gen, ChatCompletion) + return response_gen diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py new file mode 100644 index 00000000..4daa6be3 --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py @@ -0,0 +1,207 @@ +import asyncio +import atexit +import json +import os +import time +import zmq +import base64 +import json + +from loguru import logger +from typing import TYPE_CHECKING +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from openai.types.chat.chat_completion import ChatCompletion +from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX +from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor +from ajet.utils.networking import find_free_port + + +context = zmq.Context() +atexit.register(context.term) + +if TYPE_CHECKING: + from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker + +DEBUG = False +# DEBUG = True + + +def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address): + """ + Generate a Base64-encoded auth_token from the given agent_name, target_tag, and episode_uuid. + + Args: + agent_name (str): The name of the agent. + target_tag (str): The target tag. + episode_uuid (str): The UUID of the episode. + + Returns: + str: The generated auth_token in the format "Bearer ". + """ + # Step 1: Construct the auth_data dictionary + auth_data = { + "agent_name": agent_name, + "target_tag": target_tag, + "episode_uuid": episode_uuid, + "episode_address": episode_address, + } + + # Step 2: Convert the dictionary to a JSON string + json_string = json.dumps(auth_data) + + # Step 3: Encode the JSON string into Base64 + base64_encoded = base64.b64encode(json_string.encode("utf-8")).decode("utf-8") + + # Step 4: Prepend "Bearer " to the Base64-encoded string + auth_token = f"{API_KEY_PREFIX}{base64_encoded}" # API_KEY_PREFIX: Literal['sk-ajet-'] + + return auth_token + + +class InterchangeClient: + """InterchangeClient is re-created in each episode""" + + def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config): + self.episode_uuid = episode_uuid + self.context_tracker = context_tracker + self.llm_inference_fn = llm_inference_fn + self.config = config + self._should_terminate = False + + self.interchange_method = config.ajet.interchange_server.interchange_method + if self.interchange_method == "tcp": + master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") + self.episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}" + elif self.interchange_method == "ipc": + self.ipc_path = f"/tmp/ajet/{self.episode_uuid}.sock" + self.episode_contect_address = f"ipc://{self.ipc_path}" + self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads + + async def llm_infer( + self, + req: ChatCompletionRequest, + timeline_uuid: str, + agent_name: str, + target_tag: str, + episode_uuid: str, + ) -> ChatCompletion: + from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker + + req_as_dict = req.model_dump() + self.llm_proxy_with_tracker = OpenaiLlmProxyWithTracker( + context_tracker=self.context_tracker, + config=self.config, + llm_inference_fn=self.llm_inference_fn, + ) + + # infer + process with context tracker + response = await self.llm_proxy_with_tracker( + messages=req_as_dict["messages"], + tools=req_as_dict["tools"], + tool_choice="auto", + ) + + # this is an important id assignment + response.id = timeline_uuid + assert isinstance(response, ChatCompletion) + return response + + @property + def should_terminate(self) -> bool: + return self._should_terminate + + def begin_service(self): + """ + Starts the zmq communication loop. + """ + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") + self.socket = context.socket(zmq.REP) + self.socket.bind(f"{self.episode_contect_address}") + self.socket.setsockopt(zmq.RCVTIMEO, 3 * 1000) # 3 second timeout for REP + + self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...") + future = self.executor.submit(self._begin_service_threading) + + # wait till service begin running + time.sleep(0.5) + w_time = 1 + while future._state == "PENDING": + time.sleep(min(w_time * 2, 10)) + w_time += 1 + + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | Future ready...") + return self.episode_contect_address + + def _begin_service_threading(self): + """begin listening for service requests in a threading model""" + + begin_time = time.time() + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | Starting ZMQ socket bind complete") + + try: + while not self.should_terminate: + # listen for next request from remote + try: + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun") + message = self.socket.recv_string() + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") + except zmq.Again as e: + if self.should_terminate: + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | episode over") + break + timepassed = time.time() - begin_time + if timepassed > 60: + logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...") + continue + + # parse the incoming request + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | before json.loads(message)") + data_as_json = json.loads(message) + parsed_msg = InterchangeCompletionRequest(**data_as_json) + + # begin to run the llm request, monitored by context tracker + # we re-use previously created thread for best performance + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer") + try: + loop = asyncio.get_running_loop() + except: + loop = asyncio.new_event_loop() + context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() + future = loop.run_in_executor( + context_tracker_executor, + asyncio.run, + self.llm_infer( + req=parsed_msg.completion_request, + timeline_uuid=parsed_msg.timeline_uuid, + agent_name=parsed_msg.agent_name, + target_tag=parsed_msg.target_tag, + episode_uuid=parsed_msg.episode_uuid, + ), + ) + result = loop.run_until_complete(future).model_dump_json() # type: ignore + + # great, let's send back the result + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | before send_string") + self.socket.send_string(result) + except: + logger.exception(f"[client] {self.episode_uuid} | Exception occurred in service loop.") + finally: + self.socket.close() + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | ZMQ socket closed, service loop terminated.") + if self.interchange_method == "ipc": + if os.path.exists(self.ipc_path): + os.remove(self.ipc_path) + if DEBUG: + logger.info(f"[client] {self.episode_uuid} | IPC socket file {self.ipc_path} removed.") diff --git a/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py new file mode 100644 index 00000000..cb5054ee --- /dev/null +++ b/ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py @@ -0,0 +1,245 @@ +""" +A shadow FastAPI server for serving as interchange endpoint between Tuner and Workflow. + +- This functionality is experimental. +- The code is very async, considering extreme efficiency for handling many concurrent requests, + therefore, it may be hard to read. + +--------------------------------------------------------------------------------------------- + +""" + +import asyncio +import threading +import uuid +import time + +import base64 +import json +import os +import zmq +import uvicorn +import atexit +import httpx + +from loguru import logger +from pydantic import BaseModel +from fastapi import FastAPI, Header, HTTPException, Request, Body +from contextlib import asynccontextmanager +from multiprocessing import Process +from concurrent.futures import ThreadPoolExecutor + +from vllm.entrypoints.openai.protocol import ChatCompletionRequest +from openai.types.chat.chat_completion import ChatCompletion + +API_KEY_PREFIX = "sk-ajet-" + + +class InterchangeCompletionRequest(BaseModel): + completion_request: ChatCompletionRequest + agent_name: str + target_tag: str + episode_uuid: str + timeline_uuid: str + + +class HealthCheckRequest(BaseModel): + agent_name: str + target_tag: str + episode_uuid: str + timeline_uuid: str + health_check: bool = True + + +# Create FastAPI app +SERVER_SHUTDOWN_EVENT = threading.Event() +DEBUG = False +# DEBUG = True + + +context = zmq.Context() +atexit.register(context.term) + + +def get_app(max_fastapi_threads: int = 512) -> FastAPI: + @asynccontextmanager + async def lifespan(app: FastAPI): + # Startup + SERVER_SHUTDOWN_EVENT.clear() + app.state.executor = ThreadPoolExecutor(max_workers=max_fastapi_threads) + yield + # Shutdown + SERVER_SHUTDOWN_EVENT.set() + app.state.executor.shutdown(wait=False, cancel_futures=True) + + app = FastAPI(title="AJet Interchange Endpoint", lifespan=lifespan) + + def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletionRequest, episode_uuid, timeline_uuid, client_offline: threading.Event): + """run this in thread to avoid blocking main event loop""" + if DEBUG: + logger.info(f"[server] episode_uuid: {episode_uuid} | Received new chat completion request (inside thread)") + + socket = context.socket(zmq.REQ) + socket.setsockopt(zmq.RCVTIMEO, 60 * 1000) # 1 minute recv timeout + socket.connect(f"{episode_address}") + if DEBUG: + logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") + socket.send_string(int_req.model_dump_json()) + if DEBUG: + logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") + + result_str = "" + for _ in range(5): # max 5 minutes wait + try: + if DEBUG: + logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + result_str = socket.recv_string() + break + except zmq.Again as e: + if DEBUG: + logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string timeout, retrying.") + continue + + if not result_str: + raise RuntimeError(f"Failed to get response from episode_address: {episode_address} after 5 attempts.") + else: + if DEBUG: + logger.success(f"[server] episode_uuid: {episode_uuid} | recv_string done.") + result_object = ChatCompletion(**json.loads(result_str)) + return result_object + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request, authorization: str = Header(None)): + """ + OpenAI-compatible chat completions endpoint. + Receives ChatCompletionRequest and returns ChatCompletion. + """ + # Parse authorization header (base64 encoded JSON) + if not authorization: + return HTTPException(status_code=401, detail="Missing authorization header") + + try: + # Remove "Bearer " prefix if present + auth_token = authorization.replace("Bearer ", "").replace("bearer ", "").replace(API_KEY_PREFIX, "") + decoded = base64.b64decode(auth_token).decode("utf-8") + auth_data = json.loads(decoded) + + agent_name = auth_data.get("agent_name") + target_tag = auth_data.get("target_tag") + episode_uuid = auth_data.get("episode_uuid") + episode_address = auth_data.get("episode_address") + + if not all([agent_name, target_tag, episode_uuid]): + return HTTPException(status_code=401, detail="Invalid authorization data") + except Exception as e: + return HTTPException(status_code=401, detail=f"Invalid authorization header: {str(e)}") + + # Parse request body + body = await request.json() + new_req = ChatCompletionRequest.model_validate(body) + if new_req.stream: + return HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.") + # Create timeline UUID + timeline_uuid = uuid.uuid4().hex + + # Add to received queue + int_req = InterchangeCompletionRequest( + completion_request=new_req, + agent_name=agent_name, + target_tag=target_tag, + episode_uuid=episode_uuid, + timeline_uuid=timeline_uuid, + ) + if DEBUG: + logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request (outside thread)") + client_offline = threading.Event() + try: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid, timeline_uuid, client_offline) + finally: + client_offline.set() + + @app.post("/reset") + async def reset(): + return {"status": "reset_complete"} + + return app + + +class InterchangeServer(Process): + def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2, max_fastapi_threads: int = 512): + super().__init__() + self.experiment_dir = experiment_dir + self.port = port + self.num_fastapi_process = num_fastapi_process + self.max_fastapi_threads = max_fastapi_threads + + def run(self): + logger.info(f"Starting Interchange Server on port {self.port} with {self.num_fastapi_process} processes and {self.max_fastapi_threads} threads per process.") + app = get_app(self.max_fastapi_threads) + + async def serve_with_monitor(): + # Start the server + config = uvicorn.Config(app=app, host="0.0.0.0", port=self.port, log_level="error", workers=self.num_fastapi_process) + server = uvicorn.Server(config) + await server.serve() + + try: + asyncio.run(serve_with_monitor()) + except KeyboardInterrupt as e: + SERVER_SHUTDOWN_EVENT.set() + raise e + + +# Convenience function for quick server startup +def start_interchange_server(config) -> int: + experiment_dir = config.ajet.experiment_dir + num_fastapi_process = config.ajet.interchange_server.num_fastapi_process + max_fastapi_threads = config.ajet.interchange_server.max_fastapi_threads + # Find a free port if not specified or invalid + port = int(os.environ.get("AJET_DAT_INTERCHANGE_PORT", -1)) + + if config.ajet.interchange_server.interchange_server_port != "auto": + port = int(config.ajet.interchange_server.interchange_server_port) + + if port <= 0: + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port) + + interchange_server = InterchangeServer(experiment_dir, port, num_fastapi_process, max_fastapi_threads) + interchange_server.start() + + # Wait for server to be ready + health_url = f"http://localhost:{port}/health" + start_time = time.time() + while True: + if interchange_server.exitcode is not None: + logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}") + raise RuntimeError("Interchange server subprocess failed to start.") + if time.time() - start_time > 30: + msg = f"Interchange server subprocess failed to start within {time.time() - start_time} seconds." + logger.error(msg) + raise RuntimeError(msg) + try: + if httpx.get(health_url, timeout=0.5).status_code == 200: + break + except Exception: + # keep waiting + pass + time.sleep(1) + + # register a termination handler + if DEBUG: + logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})") + atexit.register(lambda: interchange_server.terminate()) + + # return port + return port diff --git a/ajet/utils/__init__.py b/ajet/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ajet/utils/async_utils.py b/ajet/utils/async_utils.py new file mode 100644 index 00000000..f6bed79c --- /dev/null +++ b/ajet/utils/async_utils.py @@ -0,0 +1,71 @@ +import asyncio +import concurrent.futures +from typing import Any + + +def run_async_coroutine_with_timeout(coro, timeout: int = 3600) -> Any: + """ + Run an async coroutine with a timeout, supporting both inside and outside event loops. + Args: + coro: The coroutine to run. + timeout (int): Timeout in seconds. Default is 3600. + Returns: + Any: The result of the coroutine. + Raises: + concurrent.futures.TimeoutError: If the coroutine does not finish in time. + """ + try: + asyncio.get_running_loop() + in_loop = True + except RuntimeError: + in_loop = False + if not in_loop: + final_res = asyncio.run(coro) + else: + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, coro) + try: + final_res = future.result(timeout=timeout) + except concurrent.futures.TimeoutError: + future.cancel() + raise + except Exception: + raise + return final_res + + +def apply_httpx_aclose_patch(): + try: + from openai._base_client import AsyncHttpxClientWrapper + + _original_init = AsyncHttpxClientWrapper.__init__ + + def _patched_init(self, *args, **kwargs): + try: + self._created_loop = asyncio.get_running_loop() + except RuntimeError: + self._created_loop = None + _original_init(self, *args, **kwargs) + + def _patched_del(self) -> None: + if self.is_closed: + return + + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + return + + if getattr(self, "_created_loop", None) is not None and current_loop is not self._created_loop: + return + + try: + current_loop.create_task(self.aclose()) + except Exception: + pass + + AsyncHttpxClientWrapper.__init__ = _patched_init + AsyncHttpxClientWrapper.__del__ = _patched_del + print("Applied httpx aclose patch.") + except ImportError: + pass diff --git a/ajet/utils/cleaner.py b/ajet/utils/cleaner.py new file mode 100644 index 00000000..2ed26de9 --- /dev/null +++ b/ajet/utils/cleaner.py @@ -0,0 +1,90 @@ +import os +import shlex +import subprocess +import time + + +def kill_ray_processes(): + """run ray stop command to kill ray processes""" + try: + print("Stopping ray processes...") + subprocess.run( + ["ray", "stop", "--force"], + check=False, + capture_output=True, + text=True, + ) + except Exception as e: + print(f"Failed to stop ray processes: {e}") + + +def fast_kill_by_keyword_bash( + keyword: str, + exclude_substrings=["vscode", "benchmark", "jupyter", "supervisord", "download_model"], + grace_seconds: float = 1.0, +): + """Use bash pipelines to kill processes matching keyword quickly. + + - Filters out processes containing any exclude_substrings + - Excludes current launcher process + - Sends TERM once to all PIDs, then KILL once to all PIDs after a short grace period + - Returns list of PIDs targeted + """ + self_pid = os.getpid() + + if "ray" in keyword: + kill_ray_processes() + + # Build a fast PID collector using pgrep if available; fallback to ps/grep + # We prefer pgrep -af to filter by full command and then extract PID (column 1) + exclude_filters = " ".join([f"| grep -v -F {shlex.quote(s)}" for s in exclude_substrings]) + pid_list_cmd = f"(pgrep -af -- {shlex.quote(keyword)} 2>/dev/null || true) " f"{exclude_filters} | awk '{{print $1}}' | grep -v -x {self_pid} || true" + + try: + res = subprocess.run( + ["bash", "-lc", pid_list_cmd], + capture_output=True, + text=True, + check=False, + ) + pids = [pid for pid in res.stdout.split() if pid.isdigit()] + except Exception as e: + print(f"Failed to list PIDs via bash: {e}") + pids = [] + + # Fallback to ps/grep if pgrep path produced nothing (e.g., no pgrep installed) + if not pids: + ps_pid_cmd = f"ps -eo pid,command -ww | grep -F -- {shlex.quote(keyword)} | grep -v grep " f"{exclude_filters} | awk '{{print $1}}' | grep -v -x {self_pid} || true" + try: + res2 = subprocess.run( + ["bash", "-lc", ps_pid_cmd], + capture_output=True, + text=True, + check=False, + ) + pids = [pid for pid in res2.stdout.split() if pid.isdigit()] + except Exception as e: + print(f"Failed to list PIDs via ps/grep: {e}") + pids = [] + + if not pids: + return [] + + pid_args = " ".join(pids) + try: + # Send TERM to all in one call + subprocess.run( + ["bash", "-lc", f"kill -TERM -- {pid_args} 2>/dev/null || true"], + check=False, + ) + time.sleep(grace_seconds) + # Escalate with KILL once; ignore failures for already-exited PIDs + subprocess.run( + ["bash", "-lc", f"kill -KILL -- {pid_args} 2>/dev/null || true"], + check=False, + ) + except Exception as e: + print(f"Error issuing kill commands: {e}") + + time.sleep(3.0) # wait for processes to exit + return [int(p) for p in pids] diff --git a/ajet/utils/color_hsl.py b/ajet/utils/color_hsl.py new file mode 100644 index 00000000..3b423e44 --- /dev/null +++ b/ajet/utils/color_hsl.py @@ -0,0 +1,41 @@ +import colorsys + + +def adjust_color_hsl(base_color, logprob): + """ + Adjust color saturation using the HSL color space based on log probability. + Args: + base_color (str): Hexadecimal color string (e.g., '#ff0000'). + logprob (float): Log probability value to determine saturation. + Returns: + str: Adjusted hexadecimal color string. + """ + # Map logprob to a saturation adjustment factor in the range [sat_min, sat_max] + sat_min = 0.333 + sat_max = 1.0 + lp_min = -7 + lp_max = 0 + + if logprob <= lp_min: + saturation_factor = sat_min + elif logprob >= 0: + saturation_factor = sat_max + else: + saturation_factor = sat_min + (logprob - lp_min) / (lp_max - lp_min) * (sat_max - sat_min) + + # Convert hexadecimal color to RGB + r = int(base_color[1:3], 16) / 255.0 + g = int(base_color[3:5], 16) / 255.0 + b = int(base_color[5:7], 16) / 255.0 + + # Convert to HSL + h, l, s = colorsys.rgb_to_hls(r, g, b) + + # Adjust saturation + s_adjusted = s * saturation_factor + + # Convert back to RGB + r_adjusted, g_adjusted, b_adjusted = colorsys.hls_to_rgb(h, l, s_adjusted) + + # Convert back to hexadecimal + return f"#{int(r_adjusted*255):02x}{int(g_adjusted*255):02x}{int(b_adjusted*255):02x}" diff --git a/ajet/utils/compute_madness.py b/ajet/utils/compute_madness.py new file mode 100644 index 00000000..a028dc57 --- /dev/null +++ b/ajet/utils/compute_madness.py @@ -0,0 +1,145 @@ +# flake8: noqa: W605 +import re +from functools import cache + +# Regex fragments for each whitelist category +WHITE_LIST_REGEX_PARTS = { + # Common symbols + "common_symbols": "‘’“”–—…•™©®°±µ′″℉℃·×", + # Chinese punctuation + "chinese_punct": ",。!?、;:“”‘’()【】《》()——……「」『』", + # Emoji ranges + "emoji": ("\U0001F300-\U0001F5FF" "\U0001F600-\U0001F64F" "\U0001F680-\U0001F6FF" "\U0001F700-\U0001F77F" "\U0001F780-\U0001F7FF" "\U0001F800-\U0001F8FF" "\U0001F900-\U0001F9FF" "\U0001FA00-\U0001FA6F" "\U0001FA70-\U0001FAFF" "\u2702-\u27B0" "\u24C2-\U0001F251"), + # Chinese characters + "chinese": ("\u4E00-\u9FFF" "\u3400-\u4DBF" "\U00020000-\U0002A6DF" "\U0002A700-\U0002B73F" "\U0002B740-\U0002B81F" "\U0002B820-\U0002CEAF" "\uF900-\uFAFF" "\U0002F800-\U0002FA1F"), +} + + +@cache +def build_pattern(white_list): + """Build a regex based on the provided whitelist categories.""" + allowed_parts = ["\x00-\x7F"] # All ASCII + for name in white_list: + if name in WHITE_LIST_REGEX_PARTS: + allowed_parts.append(WHITE_LIST_REGEX_PARTS[name]) + # Merge allowed ranges into one character class, then use a negated class to match disallowed characters + allowed_class = "".join(allowed_parts) + pattern = f"[^{allowed_class}]" # Match disallowed characters + return re.compile(pattern) + + +def has_non_ascii(text, white_list=("common_symbols", "emoji", "chinese", "chinese_punct")): + pattern = build_pattern(white_list) + return bool(pattern.search(text)) + + +def has_repeat(token, remember_n_words=5, patience_max=10): + record_words = [] + patience = patience_max + for char in token: + if char not in record_words: + record_words += [char] + if len(record_words) > remember_n_words: + record_words = record_words[1:] + patience = patience_max + else: + patience -= 1 + if patience <= 0: + return True + return False + + +def compute_string_madness(completion, detail=False, checklist=["nonsense"]) -> float: + all_reward = 0.0 + if ("nonsense" in checklist) and ("non_ascii" in checklist): + all_reward += compute_string_madness_char(completion, detail=detail) + elif ("nonsense" in checklist) and ("non_ascii" not in checklist): + all_reward += compute_string_madness_char(completion, detail=detail, skip_non_ascii=True) + if "format_type_1" in checklist: + all_reward += compute_string_madness_format(completion, detail=detail, format_type="type_1") + + return all_reward + + +def compute_string_madness_format(completion, detail, format_type) -> float: + if format_type == "type_1": + """ + + ... + + ```python + code + ``` + + """ + # Check that and appear exactly once and in order + if not completion.strip().startswith(r""): + # print("not start with ") + return -1.0 + if completion.count(r"") != 1 or completion.count(r"") != 1: + # print("not one think") + return -1.0 + if completion.index(r"") > completion.index(r""): + # print("think tag order wrong") + return -1.0 + # remove think part + think_part = completion[completion.index(r"") : completion.index(r"") + len(r"")] + rest_part = completion.replace(think_part, "") + # Check that ```python and ``` appear exactly once and in order + if not rest_part.strip().startswith(r"```python"): + # print("not start with ```python") + return -1.0 + if not rest_part.strip().endswith(r"```"): + # print("not end with ```") + return -1.0 + if rest_part.count(r"```python") != 1 or rest_part.count(r"```") != 2: + # print("not one ```python") + return -1.0 + if rest_part.index(r"```python") > rest_part.rindex(r"```"): + # print("``` tag order wrong") + return -1.0 + return 0.0 + else: + raise NotImplementedError(f"format_type {format_type} not implemented") + + +def compute_string_madness_char(completion, detail=False, skip_non_ascii=False) -> float: + # if detail: + # result = { + # "has_non_ascii": has_non_ascii(completion), + # "has_repeat": has_repeat(completion.split(), remember_n_words=5, patience_max=10), + # "has_repeat_x": has_repeat(completion, remember_n_words=4, patience_max=200), + # "has_wrong_sp_token": "<|im_start|>" in completion, + # # 'non_ascii': {ch for ch in completion if ord(ch) > 127} + # } + # if has_non_ascii(completion): + # for char in completion: + # if has_non_ascii(char): + # print(f"---") + # print(f"found non-ascii char: {char} ord={ord(char)}") + # print(result) + # return result + + if "<|im_start|>" in completion: + return -1.0 + + if skip_non_ascii: + if has_non_ascii(completion): + return -1.0 + + if has_repeat(completion.split(), remember_n_words=5, patience_max=10): + return -1.0 + + if has_repeat(completion, remember_n_words=4, patience_max=200): + return -1.0 + + return 0 + + +def repetition_penalty_reward_scalar_debug(completion): + for i in range(len(completion)): + p = completion[:i] + result = compute_string_madness(p) + if result != 0: + return completion + return "" diff --git a/ajet/utils/config_computer.py b/ajet/utils/config_computer.py new file mode 100644 index 00000000..c3a71d7e --- /dev/null +++ b/ajet/utils/config_computer.py @@ -0,0 +1,217 @@ +import ast +import re +from typing import Any, Callable, Dict, List, Tuple + + +# Abstract Syntax Tree Visitor to extract variable names +class AstStructureExtractor(ast.NodeVisitor): + """Visitor pattern to extract all keys (variable names)""" + + def __init__(self): + self.keys = set() + # Define builtin function list to avoid dependency on different behaviors of __builtins__ + self.builtin_names = { + "min", + "max", + "abs", + "round", + "int", + "float", + "sum", + "len", + "str", + "bool", + "list", + "dict", + "tuple", + "set", + "range", + "enumerate", + "zip", + "map", + "filter", + "sorted", + "reversed", + "all", + "any", + "bin", + "hex", + "oct", + "chr", + "ord", + "pow", + "divmod", + "type", + "isinstance", + "hasattr", + "getattr", + "setattr", + "delattr", + "callable", + "iter", + "next", + # Add other potential builtin functions as needed + } + + def visit_Name(self, node): + # Collect all variable names, excluding builtin functions + if node.id not in self.builtin_names: + self.keys.add(node.id) + self.generic_visit(node) + + def visit_Attribute(self, node): + # Handle attribute access like "ajet.rollout.max_env_worker" + # Reconstruct the full attribute path + full_key = self._get_full_attribute_name(node) + if full_key and not self._is_builtin_attribute(full_key): + self.keys.add(full_key) + # Don't call generic_visit to avoid duplicate processing of child nodes + + def _get_full_attribute_name(self, node): + """Recursively get the full attribute name""" + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Attribute): + base = self._get_full_attribute_name(node.value) + if base: + return f"{base}.{node.attr}" + return None + + def _is_builtin_attribute(self, attr_name): + """Check if it's an attribute of a builtin module (like math.sin)""" + # Can extend this list as needed + builtin_modules = {"math", "os", "sys", "json", "re", "datetime"} + parts = attr_name.split(".") + return len(parts) > 1 and parts[0] in builtin_modules + + +def split_keys_and_operators(operation_str: str, preserved_field: List[str] = []) -> Tuple[List[str], Callable[[Dict[str, Any]], Any]]: + """ + Parse expression string using AST and extract keys and operators + + Input example: (min(ajet.rollout.max_env_worker // ajet.rollout.n_vllm_engine, 64)) + Output example: (['ajet.rollout.max_env_worker', 'ajet.rollout.n_vllm_engine'], ) + """ + + # Parse the expression + try: + tree = ast.parse(operation_str, mode="eval") + except SyntaxError as e: + raise ValueError(f"Expression syntax error: {operation_str}") from e + + # use Abstract Syntax Tree to extract all keys + extractor = AstStructureExtractor() + extractor.visit(tree) + keys = sorted(list(extractor.keys)) + + # Create evaluation function + def eval_func(values: Dict[str, Any]) -> Any: + # Check if all required keys exist + missing_keys = [key for key in keys if key not in values] + if missing_keys: + raise ValueError(f"Missing required keys: {missing_keys}") + + # Create mapping from key names to safe variable names + key_mapping = {} + safe_expression = operation_str + + # Sort by key length in descending order to replace longer keys first + sorted_keys = sorted(keys, key=len, reverse=True) + + for i, key in enumerate(sorted_keys): + # Create a safe variable name for each key + safe_var = f"var_{i}" + key_mapping[safe_var] = values[key] + + # Use regex to precisely match and replace key names + # Ensure no partial matching (e.g., won't match "a.b.c" in "a.b.cd") + pattern = re.escape(key) + r"(?![a-zA-Z0-9_.])" + safe_expression = re.sub(pattern, safe_var, safe_expression) + + # Create a safe namespace for evaluation + namespace = { + "__builtins__": { + "min": min, + "max": max, + "abs": abs, + "round": round, + "int": int, + "float": float, + "sum": sum, + "len": len, + # Can add more safe builtin functions as needed + } + } + + # Add mapped variables to namespace + namespace.update(key_mapping) + + # Evaluate the expression + try: + result = eval(safe_expression, namespace) + return result + except Exception as e: + raise RuntimeError(f"Error evaluating expression '{operation_str}': {e}") from e + + # print(f"Extracted keys: {keys}") + return keys, eval_func + + +# # Test examples +# if __name__ == "__main__": +# # Example 1 +# operation_str1 = "(ajet.data.train_batch_size * ajet.rollout.num_repeat * ajet.rollout.multi_turn.expected_steps)" +# known_operators1 = [] + +# keys1, func1 = split_keys_and_operators(operation_str1) +# print("Example 1:") +# print(f"Extracted keys: {keys1}") + +# values1 = { +# "ajet.data.train_batch_size": 32, +# "ajet.rollout.num_repeat": 4, +# "ajet.rollout.multi_turn.expected_steps": 10, +# } +# result1 = func1(values1) +# print(f"Computed result: {result1}") # 32 * 4 * 10 = 1280 +# print() + +# # Example 2 +# operation_str2 = "(ajet.rollout.max_env_worker // ajet.rollout.n_vllm_engine)" +# known_operators2 = [] + +# keys2, func2 = split_keys_and_operators(operation_str2) +# print("Example 2:") +# print(f"Extracted keys: {keys2}") + +# values2 = {"ajet.rollout.max_env_worker": 100, "ajet.rollout.n_vllm_engine": 8} +# result2 = func2(values2) +# print(f"Computed result: {result2}") # 100 // 8 = 12 +# print() + +# # Example 3: Mixed operators +# operation_str3 = "(a * b / c + d - e)" +# known_operators3 = [] + +# keys3, func3 = split_keys_and_operators(operation_str3) +# print("Example 3:") +# print(f"Extracted keys: {keys3}") + +# values3 = {"a": 100, "b": 5, "c": 10, "d": 20, "e": 5} +# result3 = func3(values3) +# print(f"Computed result: {result3}") # 100 * 5 / 10 + 20 - 5 = 65.0 + +# # Example 4 +# operation_str4 = "(min(ajet.rollout.max_env_worker // ajet.rollout.n_vllm_engine, 64))" +# known_operators4 = [] + +# keys4, func4 = split_keys_and_operators(operation_str4) +# print("Example 4:") +# print(f"Extracted keys: {keys4}") + +# values4 = { +# "ajet.rollout.max_env_worker": 512, +# "ajet.rollout.n_vllm_engine": 4, +# } +# result4 = func4(values4) +# print(f"Computed result: {result4}") # 64 diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py new file mode 100644 index 00000000..8983f733 --- /dev/null +++ b/ajet/utils/config_utils.py @@ -0,0 +1,307 @@ +import os +import shutil +import time +from functools import cache + +import yaml +from beast_logger import print_dict +from hydra import compose, initialize +from loguru import logger +from omegaconf import DictConfig + +from ajet.utils.config_computer import split_keys_and_operators + + +def read_ajet_config(yaml_fp): + """Load a Hydra configuration relative to this module.""" + yaml_fp = os.path.relpath(yaml_fp, os.path.dirname(__file__)) # do not try to understand this line, hydra is too weird + + def load_hydra_config(config_path: str, config_name: str) -> DictConfig: + with initialize(config_path=config_path, version_base=None): + cfg = compose(config_name=config_name, overrides=[]) + return cfg + + dir_path = os.path.dirname(yaml_fp) + file_name = os.path.basename(yaml_fp) + return load_hydra_config(config_path=dir_path, config_name=file_name) + + +@cache +def read_ajet_config_with_cache(yaml_fp): + """Load a Hydra configuration relative to this module with caching.""" + return read_ajet_config(yaml_fp) + + +def dump_yaml_config(cfg: DictConfig, yaml_fp: str): + """Persist the provided OmegaConf config to ``yaml_fp``.""" + from omegaconf import OmegaConf + + with open(yaml_fp, "w") as f: + OmegaConf.save(cfg, f) + return yaml_fp + + +def _dive_to_fetch_value(config, dotted_key): + keys = dotted_key.split(".") + value = config + for key in keys: + value = value.get(key, None) + if value is None: + break + if value is None: + raise ValueError(f"[Warning]: Cannot find value for key: {dotted_key} in {config}") + return value + + +def _dive_to_set_value(config, dotted_key, value): + keys = dotted_key.split(".") + sub_config = config + for key in keys[:-1]: + if key not in sub_config: + sub_config[key] = {} + sub_config = sub_config[key] + sub_config[keys[-1]] = value + + +def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone): + """Align configuration values based on a conversion map. + + Parameters + ---------- + from_config_fp : str + Source YAML path to read values from. + to_config_fp : str + Destination YAML path that is updated in place. + convertion_json_fg : str + JSON path mapping dotted keys between configs. + backbone : str + Backbone identifier used for framework-specific alignment. + """ + # read yaml files + with open(from_config_fp, "r") as file: + from_config = yaml.safe_load(file) + with open(to_config_fp, "r") as file: + to_config = yaml.safe_load(file) + + # read convertion json + import json + + with open(convertion_json_fg, "r") as file: + convertion_json = json.load(file) + + logger.success("----------------------------------------------------") + # align trinity.* to to_config + if ("trinity" in from_config) and backbone == "trinity": + trinity_config = from_config["trinity"] + + def recursive_copy(src_dict, dst_dict, parent_key=""): + for key, value in src_dict.items(): + full_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, dict): + if key not in dst_dict: + dst_dict[key] = {} + recursive_copy(value, dst_dict[key], full_key) + else: + dst_dict[key] = value + + recursive_copy(trinity_config, to_config) + + # align based on convertion_json + for from_key, to_keys in convertion_json.items(): + if from_key.startswith("("): + # special argument that need A.S.T. computation + keys_array, config_computer = split_keys_and_operators(from_key, []) + value = config_computer({k: _dive_to_fetch_value(from_config, k) for k in keys_array}) + else: + # normal argument + value = _dive_to_fetch_value(from_config, from_key) + + # multiple to_keys support + to_keys = to_keys if isinstance(to_keys, list) else [to_keys] + + # set and override config value + for to_key in to_keys: + _dive_to_set_value(to_config, to_key, value) + logger.success(f"[Note]: Aligned parameter from [{from_key}] to [{to_key}] with value: [{value}]") + + # backbone specific safe guard + to_config = config_safe_guard(to_config, backbone) + + # save to_config_fp + with open(to_config_fp, "w") as file: + yaml.dump(to_config, file) + + # logger.success(f"Saved aligned configuration to {to_config_fp}") + print_dict({"Note": f"Saved aligned configuration to {to_config_fp}"}, header="Final Configuration") + + +def config_safe_guard(config: dict, backbone: str) -> dict: + # special: logger + if backbone == "verl" and isinstance(config["trainer"]["logger"], str): + config["trainer"]["logger"] = ["console", config["trainer"]["logger"]] + + # special: trinity train_batch_size + if backbone == "trinity": + train_batch_size = config["buffer"]["train_batch_size"] + world_size = config["cluster"]["gpu_per_node"] * config["cluster"]["node_num"] + vllm_world_size = config["explorer"]["rollout_model"]["tensor_parallel_size"] * config["explorer"]["rollout_model"]["engine_num"] + fsdp_world_size = world_size - vllm_world_size + + # if train_batch_size % fsdp_world_size != 0, train_batch_size + until divisible + if fsdp_world_size > 0 and train_batch_size % fsdp_world_size != 0: + new_train_batch_size = train_batch_size + while new_train_batch_size % fsdp_world_size != 0: + new_train_batch_size += 1 + logger.warning(f"[Warning]: trinity backbone detected, but train_batch_size {train_batch_size} is not divisible by fsdp_world_size {fsdp_world_size}. Automatically adjust train_batch_size to {new_train_batch_size}.") + config["buffer"]["train_batch_size"] = new_train_batch_size + + return config + + +def read_ajet_hierarchical_config(yaml_fp, exp_name, backbone, write_to=None, exp_dir="saved_experiments"): + if yaml_fp is None: + config = { + "ajet": {}, + "hydra": { + "searchpath": [ + "file://ajet/default_config", + "file://ajet/default_config/verl", + "file://ajet/default_config/trinity", + ] + }, + "defaults": [ + "verl_default", + "trinity_default", + "ajet_default", + "_self_", + ], + } + else: + with open(yaml_fp, "r") as file: + config = yaml.safe_load(file) + config["ajet"]["experiment_name"] = exp_name + config["ajet"]["experiment_dir"] = os.path.join(exp_dir, exp_name) + config["ajet"]["backbone"] = backbone + + # remove extra config of verl for trinity + if backbone == "debug": + if "trinity_default" in config["defaults"]: + config["defaults"].remove("trinity_default") + config["hydra"]["searchpath"].remove("file://ajet/default_config/trinity") + # remove extra config of verl for trinity + if backbone == "trinity": + if "verl_default" in config["defaults"]: + config["defaults"].remove("verl_default") + config["hydra"]["searchpath"].remove("file://ajet/default_config/verl") + # remove extra config of trinity for verl + if backbone == "verl": # or args.backbone == "debug" + if "trinity_default" in config["defaults"]: + config["defaults"].remove("trinity_default") + config["hydra"]["searchpath"].remove("file://ajet/default_config/trinity") + + if write_to: + with open(write_to, "w") as file: + yaml.dump(config, file) + return config + + +def expand_ajet_hierarchical_config(config, write_to=None): + # create temp yaml file + import tempfile + + with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml: + yaml_path = temp_yaml.name + with open(yaml_path, "w") as file: + yaml.dump(config, file) + full_config = read_ajet_config(yaml_path) + yaml_path = dump_yaml_config(full_config, yaml_fp=yaml_path) + # put inherit info back + with open(yaml_path, "r") as file: + config_final = yaml.safe_load(file) + config_final["defaults"] = config["defaults"] + config_final["hydra"] = config["hydra"] + + if write_to: + with open(write_to, "w") as file: + yaml.dump(config_final, file) + + return config_final + + +def prepare_experiment_config(yaml_path, exp_dir, backbone): + """ + Prepare experiment configuration by reading YAML, setting up backup directories, + and copying necessary files for the experiment. + + Args: + yaml_path: Path to the YAML configuration file + exp_dir: Directory where experiment artifacts and backups should be stored + backbone: Backbone identifier that controls config munging + + Returns: + tuple: (yaml_backup_dst, exe_exp_base, exp_name, config_final) + """ + assert yaml_path.endswith(".yaml"), "Configuration file must be a YAML file" + exp_base = os.path.dirname(yaml_path) + + if not os.path.exists(exp_base): + raise FileNotFoundError(f"Configuration file not found: {exp_base}") + + ## 0. read yaml & get experiment_name + with open(yaml_path, "r") as file: + config = yaml.safe_load(file) + try: + exp_name = config.get("ajet").get("experiment_name") + except Exception: + raise ValueError(f"Please set ajet field in yaml file. Current yaml:\n{config}") + if exp_name is None or exp_name == "read_yaml_name": + if exp_name is not None: + exp_name = exp_name.replace("|", "-") + exp_name = os.path.basename(yaml_path).replace(".yaml", "") + # add timestamp to exp_name + timestamp = time.strftime("%Y%m%d_%H%M", time.localtime()) + exp_name = f"{exp_name}_{timestamp}" + else: + exp_name = exp_name.replace("|", "-") + + backup_dir = os.path.join(exp_dir, exp_name, "backup") + yaml_backup_dst = os.path.join(exp_dir, exp_name, "yaml_backup.yaml") + yaml_backup_dst = os.path.abspath(yaml_backup_dst) + exe_exp_base = os.path.dirname(yaml_backup_dst) + + logger.info("----------------------------------------") + logger.info(f"Experiment Name: {exp_name}") + logger.info(f"Experiment Backup Dir: {backup_dir}") + logger.info(f"Experiment Yaml Dir: {yaml_backup_dst}") + logger.info("----------------------------------------") + + ## 1. check exp_base/backup exist + if not os.path.exists(backup_dir): + os.makedirs(backup_dir) + else: + total_seconds = 5 + for i in range(total_seconds): + logger.warning(f"Warning: backup directory already exists, we will automatically ignore this after {total_seconds - i} seconds...") + time.sleep(1) + + ## 2. copy files to backup + BACK_TARGETS = os.environ.get("BACK_TARGETS", "").split(",") + BACK_TARGETS = [p for p in BACK_TARGETS if os.path.exists(p)] + + for backup_target in BACK_TARGETS: + logger.info(f"Copying {backup_target} to {os.path.join(backup_dir, os.path.basename(backup_target))}") + shutil.copytree( + backup_target, + os.path.join(backup_dir, os.path.basename(backup_target)), + dirs_exist_ok=True, + ) + + ## 3. copy yaml to backup + yaml_backup_src = yaml_path + shutil.copyfile(yaml_backup_src, yaml_backup_dst) + + ## 4. edit new yaml + config = read_ajet_hierarchical_config(yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir) + config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst) + + return yaml_backup_dst, exe_exp_base, exp_name, config_final diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py new file mode 100644 index 00000000..91fdf736 --- /dev/null +++ b/ajet/utils/core_env_vars.py @@ -0,0 +1,60 @@ +import os +from pathlib import Path + +from beast_logger import print_dict +from dotenv import load_dotenv +from ajet.utils.networking import find_free_port, get_host_ip + + +def get_runtime_env(config, is_trinity: bool = False) -> dict: + if os.path.exists(".env"): + load_dotenv(".env") + + master_node_ip = get_host_ip(os.environ.get("NETWORK_INTERFACE", None)) + if config.ajet.trainer_common.nnodes == 1: + master_node_ip = "localhost" + else: + if config.ajet.enable_experimental_interchange_server: + if config.ajet.interchange_server.interchange_method == "ipc": + raise ValueError("IPC interchange method is not supported for multi-node setup. Please set `ajet.interchange_server.interchange_method: tcp` ") + + runtime_env = { + "env_vars": { + "VLLM_USE_V1": "1", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "TOKENIZERS_PARALLELISM": "true", + # use ajet.backbone as plugin directory + "TRINITY_PLUGIN_DIRS": str((Path(__file__).parent.parent / "backbone").resolve()), + # "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", + "SWANLAB_API_KEY": os.getenv("SWANLAB_API_KEY", ""), + "AJET_CONFIG_REDIRECT": os.getenv("AJET_CONFIG_REDIRECT", ""), + "AJET_DAT_INTERCHANGE_PORT": str(find_free_port()), + "MASTER_NODE_IP": master_node_ip, + } + } + + optional_env_vars = [ + "RAY_record_task_actor_creation_sites", + "BEST_LOGGER_WEB_SERVICE_URL", + "AJET_GIT_HASH", + "AJET_REQ_TXT", + "AJET_BENCHMARK_NAME", + "FINANCE_MCP_URL", + # API Keys for RM Gallery and other services + "DASHSCOPE_API_KEY", + "OPENAI_API_KEY", + "OPENAI_BASE_URL", + "API_KEY", + "BASE_URL", + ] + + for var in optional_env_vars: + if os.getenv(var): + runtime_env["env_vars"].update({var: os.getenv(var, "")}) + + if is_trinity: + assert "AJET_CONFIG_REDIRECT" in runtime_env["env_vars"] + + print_dict(runtime_env["env_vars"], "runtime_env") + return runtime_env diff --git a/ajet/utils/dynamic_import.py b/ajet/utils/dynamic_import.py new file mode 100644 index 00000000..1d36b4c6 --- /dev/null +++ b/ajet/utils/dynamic_import.py @@ -0,0 +1,89 @@ +import importlib +import importlib.util +import os +import sys +import threading +from typing import Any, Callable, Union + + +def cls_to_path(obj_or_path: Union[str, Callable[..., Any]]) -> str: + """Convert a callable to the ``module->name`` string expected by dynamic_import.""" + + if isinstance(obj_or_path, str): + return obj_or_path + module = getattr(obj_or_path, "__module__", None) + name = getattr(obj_or_path, "__name__", None) + if module and name: + return f"{module}->{name}" + raise ValueError("Object must be a dotted string or a callable with __module__ and __name__.") + + +def _dynamic_import(module_class_str: str): + """ + Dynamic import of class from module + Supports two formats: + 1. module.path->ClassName (dot-separated module path) + 2. path/to/module.py->ClassName (file path format, can be absolute or relative) + """ + module_str, class_name = module_class_str.split("->") + + # Use .py-> as identifier for file path format + if ".py->" in module_class_str: + # Handle file path format + file_path = module_str + # Split module name + module_name = os.path.splitext(os.path.basename(file_path))[0] + + # check if module already loaded + if module_name in sys.modules: + module = sys.modules[module_name] + else: + # Convert to absolute path + if not os.path.isabs(file_path): + file_path = os.path.abspath(file_path) + if not os.path.exists(file_path): + raise ImportError(f"Module file not found: {file_path}") + # Load module from file using importlib.util + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot create module spec for: {file_path}") + + module = importlib.util.module_from_spec(spec) + # Load module from file using importlib.util + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot create module spec for: {file_path}") + + module = importlib.util.module_from_spec(spec) + + # Add module to sys.modules BEFORE execution to avoid duplicate loading + sys.modules[module_name] = module + + # Execute module + spec.loader.exec_module(module) + + else: + # Standard module path format + module = importlib.import_module(module_str) + + # Get class + try: + protocol_cls = getattr(module, class_name) + except Exception as e: + raise ImportError(f"Cannot import class {class_name} from module {module_str}: {e}") from e + return protocol_cls + + +_import_lock = threading.RLock() + + +def dynamic_import(module_class_str: str): + """ + Thread-safe dynamic import of a class from a module string. + Args: + module_class_str (str): Module and class string, e.g., 'module.path->ClassName' or 'path/to/module.py->ClassName'. + Returns: + type: The imported class type. + """ + with _import_lock: + return _dynamic_import(module_class_str) diff --git a/ajet/utils/embedding_client.py b/ajet/utils/embedding_client.py new file mode 100644 index 00000000..1fa89426 --- /dev/null +++ b/ajet/utils/embedding_client.py @@ -0,0 +1,338 @@ +import os +import uuid +from typing import Any, Dict, List, Optional, Sequence, Union, cast + +import httpx +from loguru import logger +from tenacity import retry, stop_after_attempt, wait_exponential + + +class OpenAIEmbeddingClient: + """ + Client class for OpenAI Embedding API. + Supports calling embedding APIs in OpenAI format with rate limiting. + """ + + def __init__( + self, + api_key: str, + base_url: str = "https://api.openai.com/v1", + model_name: str = "text-embedding-ada-002", + rate_limit_calls: int = 60, + rate_limit_window: int = 60, + ): + """ + Initializes the OpenAI Embedding API client. + + Args: + api_key (str): The API key for authentication. + base_url (str): The base URL for the API, defaulting to the official OpenAI address. + model_name (str): The name of the model to use, defaulting to text-embedding-ada-002. + rate_limit_calls (int): The number of allowed calls within the rate limit window, defaulting to 60. + rate_limit_window (int): The time window in seconds for the rate limit, defaulting to 60 seconds. + """ + self.api_key = api_key + self.base_url = base_url.rstrip("/") # ⭐ Ensures the base URL does not end with a trailing slash + self.model_name = model_name + # Set up the request headers + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", # ⭐ Constructs the authorization header using the provided API key + } + + logger.info(f"init OpenAI Embedding client, quota: {rate_limit_calls} times/{rate_limit_window}s") + + @retry(stop=stop_after_attempt(4), wait=wait_exponential(multiplier=1, min=4, max=60)) + def get_embeddings( + self, + texts: Union[str, Sequence[str]], + model: Optional[str] = None, + encoding_format: str = "float", + dimensions: Optional[int] = None, + user: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Fetches the embedding vectors for the provided texts with rate limiting. + + Args: + texts (Union[str, Sequence[str]]): Text(s) for which to fetch the embeddings, can be a single string or a list of strings. + model (Optional[str]): Name of the model to use; if not specified, the model set during initialization is used. + encoding_format (str): Encoding format for the embeddings, default is "float". + dimensions (Optional[int]): Output dimensionality (supported by some models). + user (Optional[str]): User identifier. + + Returns: + Dict[str, Any]: The API response as a dictionary. + + Raises: + requests.RequestException: If there is an issue with the request. + ValueError: If the input parameters are invalid. + """ + # Rate limiting control + # self.rate_limiter.acquire() # ⭐ Acquires a token from the rate limiter to ensure the request does not exceed the allowed rate + + # Parameter validation + if not texts: + raise ValueError("texts cannot be empty") + + # Construct the request payload + payload = { + "input": texts, + "model": model or self.model_name, + "encoding_format": encoding_format, + } + + # Add optional parameters + if dimensions is not None: + payload["dimensions"] = dimensions + if user is not None: + payload["user"] = user + + # Send the request + url = f"{self.base_url}/embeddings" + + try: + with httpx.Client(timeout=60.0) as client: + response = client.post( + url, + headers=self.headers, + json=payload, + ) + if not response.is_success: + logger.error(f"failed to request embedding: {response.status_code} {response.reason_phrase}") + try: + logger.error(f"err json: {response.json()}") + except Exception: + logger.error("err json: ") + response.raise_for_status() + + return response.json() + + except httpx.RequestError as e: + raise httpx.RequestError(f"failed to request embedding: {e}") + + def get_single_embedding(self, text: str, **kwargs) -> List[float]: + """ + Retrieves the embedding vector for a single piece of text. This is a simplified method that wraps around the `get_embeddings` method. + + Args: + text (str): The text for which to retrieve the embedding vector. + **kwargs: Additional arguments to pass to the `get_embeddings` method. + + Returns: + List[float]: The embedding vector for the provided text. + """ + result = self.get_embeddings(text, **kwargs) # ⭐ Calls the get_embeddings method with the given text and additional arguments + return result["data"][0]["embedding"] + + def get_multiple_embeddings(self, texts: Sequence[str], **kwargs) -> List[List[float]]: + """ + Retrieves the embedding vectors for multiple texts (simplified method). + + Args: + texts (Sequence[str]): A list of texts to get the embedding vectors for. + **kwargs: Additional arguments to pass to the `get_embeddings` method. + + Returns: + List[List[float]]: A list of embedding vectors. + """ + result = self.get_embeddings(texts, **kwargs) # ⭐ Calls the `get_embeddings` method with provided texts and additional arguments + return [item["embedding"] for item in result["data"]] # ⭐ Extracts the 'embedding' field from each item in the returned data + + def set_model(self, model_name: str): + """ + Sets the default model name for the API client. + + Args: + model_name (str): The name of the model to be used. + """ + self.model_name = model_name # ⭐ Set the model name + + def set_base_url(self, base_url: str): + """ + Sets the base URL for the API, ensuring it does not end with a trailing slash. + + Args: + base_url (str): The base URL for the API. + """ + self.base_url = base_url.rstrip("/") # ⭐ Remove trailing slash if present + + def set_api_key(self, api_key: str): + """ + Sets the API key and updates the authorization header for the API requests. + + Args: + api_key (str): The API key for authentication. + """ + self.api_key = api_key + self.headers["Authorization"] = f"Bearer {self.api_key}" # ⭐ Update the authorization header + + +class EmbeddingClient: + def __init__( + self, + similarity_threshold: float, + base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1", + api_key: Optional[str] = None, + model: str = "text-embedding-v4", + chroma_db_path: str = "./chroma_db", + collection_name: str = "trajectories", + ): + import chromadb + from chromadb.config import Settings + + api_key = api_key or os.getenv("DASHSCOPE_API_KEY") + assert api_key is not None, "DASHSCOPE_API_KEY is required" + + self._client = OpenAIEmbeddingClient(api_key=api_key, base_url=base_url, model_name=model) + self.similarity_threshold = similarity_threshold + self._chroma_client = chromadb.PersistentClient(path=chroma_db_path, settings=Settings(anonymized_telemetry=False)) + + self._collection = self._chroma_client.get_or_create_collection(name=collection_name, metadata={"hnsw:space": "cosine"}) + + def add(self, text: str, id: int): + """ + Add text and ID to ChromaDB + """ + embedding = self._client.get_single_embedding(text) + + chroma_id = f"doc_{id}_{uuid.uuid4().hex[:8]}" + + self._collection.add( + embeddings=[embedding], + documents=[text], + ids=[chroma_id], + metadatas=[{"original_id": id, "text_length": len(text)}], + ) + + def find_by_text(self, text: str) -> Optional[int]: + """ + Find a similar text in ChromaDB, return the corresponding ID + """ + if self._collection.count() == 0: + return None + + query_embedding = self._client.get_single_embedding(text) + + results = self._collection.query( + query_embeddings=[query_embedding], + n_results=1, # only the top result + include=["documents", "metadatas", "distances"], + ) + + if not results["ids"] or not results["ids"][0]: + return None + + distance = results["distances"][0][0] # type: ignore + similarity = 1 - distance + + if similarity >= self.similarity_threshold: + # Get the original_id from metadata instead of using _reverse_id_mapping + metadata = results["metadatas"][0][0] # type: ignore + return cast(int | None, metadata.get("original_id")) + else: + return None + + def find_top_k_by_text(self, text: str, k: int = 5) -> list[tuple[int, float, str]]: + """ + Find the top k similar documents + """ + if self._collection.count() == 0: + return [] + + query_embedding = self._client.get_single_embedding(text) + + results = self._collection.query( + query_embeddings=[query_embedding], + n_results=min(k, self._collection.count()), + include=["documents", "metadatas", "distances"], + ) + + if not results["ids"] or not results["ids"][0]: + return [] + + result_list = [] + for i in range(len(results["ids"][0])): + distance = results["distances"][0][i] # type: ignore + similarity = 1 - distance + document = results["documents"][0][i] # type: ignore + # Get the original_id from metadata instead of using _reverse_id_mapping + metadata = results["metadatas"][0][i] # type: ignore + original_id = metadata.get("original_id") + + if original_id is not None: + result_list.append((original_id, similarity, document)) + + return result_list + + def _embedding(self, texts: Sequence[str], bs=10) -> list[list[float]]: + """ + Get the embedding of texts + """ + res: list[list[float]] = [] + for i in range(0, len(texts), bs): + res.extend(self._client.get_multiple_embeddings(texts[i : i + bs])) + + return res + + def get_all_stored_texts(self) -> dict[int, str]: + """ + Get all stored texts + """ + all_data = self._collection.get(include=["documents", "metadatas"]) + result = {} + + if all_data["ids"]: + for i in range(len(all_data["ids"])): + # Get the original_id from metadata instead of using _reverse_id_mapping + metadata = all_data["metadatas"][i] # type: ignore + original_id = metadata.get("original_id") + if original_id is not None: + result[original_id] = all_data["documents"][i] # type: ignore + + return result + + def exists(self, id: int) -> bool: + """ + Check if the ID exists + """ + results = self._collection.get(where={"original_id": id}, include=[]) + return bool(results["ids"]) + + def remove(self, id: int) -> bool: + """ + Remove the text and embedding vector of the specified ID + """ + # Find the chroma_id by querying for the document with the specified original_id + results = self._collection.get(where={"original_id": id}, include=["metadatas"]) + + if not results["ids"] or not results["ids"][0]: + return False + + chroma_id = results["ids"][0] + + try: + self._collection.delete(ids=[chroma_id]) + return True + except Exception: + return False + + def clear(self): + """clear all stored texts and embeddings""" + try: + self._chroma_client.delete_collection(self._collection.name) + self._collection = self._chroma_client.get_or_create_collection(name=self._collection.name, metadata={"hnsw:space": "cosine"}) + except Exception as e: + print(f"failed to clear stores: {e}") + + def size(self) -> int: + """get the number of stored texts""" + return self._collection.count() + + def get_collection_info(self) -> dict: + """get the collection info of ChromaDB""" + return { + "name": self._collection.name, + "count": self._collection.count(), + "metadata": self._collection.metadata, + } diff --git a/ajet/utils/env_service_client/__init__.py b/ajet/utils/env_service_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/astune/env_service_client/env_client_ng.py b/ajet/utils/env_service_client/env_client_ng.py similarity index 57% rename from astune/env_service_client/env_client_ng.py rename to ajet/utils/env_service_client/env_client_ng.py index c4e456f8..93af5d76 100644 --- a/astune/env_service_client/env_client_ng.py +++ b/ajet/utils/env_service_client/env_client_ng.py @@ -1,25 +1,16 @@ # env_client.py -from typing import Dict, List, Any, Optional, Callable -import requests -import time -import random import os -from datetime import datetime +import random +import tempfile +import time +from typing import Any, Callable, Dict, List, Optional + +import requests +from loguru import logger + +LOG_PATH = os.environ.get("CLIENT_LOG_PATH", os.path.join(tempfile.gettempdir(), "app_logs", "error.out")) -LOG_PATH = os.environ.get('CLIENT_LOG_PATH', "/mnt/data/eric.czq/rl_log/error.out") -# map_env_type = { -# 'appworld2': 'appworld', -# } -def safe_log(msg: str): - try: - os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True) - with open(LOG_PATH, "a", encoding="utf-8") as f: - f.write(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}] {msg}\n") - f.flush() - os.fsync(f.fileno()) - except Exception: - pass # 防止日志写失败影响RL主进程 def retry_call( fn: Callable, @@ -28,39 +19,38 @@ def retry_call( max_backoff: float = 10.0, fail_return: Any = None, err_prefix: str = "", - instance_id: str = "", - action_name: str = "" + instance_id: str | None = "", + action_name: str = "", ): - last_exception = None for i in range(max_retry): try: res = fn() - if i>0: - safe_log(f"{err_prefix} {action_name} [instance={instance_id}] succeed at try {i+1}/{max_retry}") + if i > 0: + logger.info(f"{err_prefix} {action_name} [instance={instance_id}] succeed at try {i+1}/{max_retry}") return res except Exception as e: - last_exception = e - safe_log(f"{err_prefix} {action_name} [instance={instance_id}] retry {i+1}/{max_retry} failed: {e}") + logger.info(f"{err_prefix} {action_name} [instance={instance_id}] retry {i+1}/{max_retry} failed: {e}") if i + 1 == max_retry: - safe_log(f"{err_prefix} {action_name} [instance={instance_id}] max retries exceeded, fallback used.") - return fail_return + logger.exception(f"{err_prefix} {action_name} [instance={instance_id}] max retries exceeded, fallback used.") + raise RuntimeError("Env Service Timeout") wait = random.uniform(min_backoff, max_backoff) time.sleep(wait) return fail_return + class EnvClient: def __init__(self, base_url: str = "http://localhost:8000"): self.base_url = base_url.rstrip("/") - self.timeout = 150.0+random.uniform(50, 200) + self.timeout = 30.0 def _make_request( self, endpoint: str, env_type: str = "default", - task_id: str = None, - instance_id: str = None, - messages: Dict[str, Any] = None, - params: Dict[str, Any] = None, + task_id: str | None = None, + instance_id: str | None = None, + messages: Dict[str, Any] | None = None, + params: Dict[str, Any] | None = None, ) -> Dict: url = f"{self.base_url}/{endpoint.lstrip('/')}" @@ -77,22 +67,17 @@ def _make_request( response.raise_for_status() return response.json() except Exception as e: - safe_log( - f"[{endpoint}] _make_request failed (instance={instance_id}): {e}, data: {data}" - ) - raise Exception( - f"Request failed: {str(e)}, data: {data}" - ) + logger.exception(f"[{endpoint}] _make_request failed (instance={instance_id}): {e}, data: {data}") + raise Exception(f"Request failed: {str(e)}, data: {data}") def get_env_profile( self, env_type: str, split: str = "train", params: Optional[dict] = None, - max_retry: int = 3 + max_retry: int = 1, ) -> List[str]: def call(): - # 使用新的变量名,避免修改外部参数 # resolved_env_type = map_env_type.get(env_type, env_type) response = self._make_request( endpoint="/get_env_profile", @@ -111,11 +96,15 @@ def call(): max_retry=max_retry, fail_return=[], err_prefix="[get_env_profile]", - action_name="get_env_profile" + action_name="get_env_profile", ) def get_tools_info( - self, instance_id: str, messages: Dict = {}, params: Dict = {}, max_retry: int = 3 + self, + instance_id: str, + messages: Dict = {}, + params: Dict = {}, + max_retry: int = 3, ) -> Any: def call(): response = self._make_request( @@ -125,13 +114,14 @@ def call(): params=params, ) return response.get("data", None) + return retry_call( call, max_retry=max_retry, fail_return=None, - err_prefix=f"[get_tools_info]", + err_prefix="[get_tools_info]", instance_id=instance_id, - action_name="get_tools_info" + action_name="get_tools_info", ) def create_instance( @@ -140,15 +130,27 @@ def create_instance( task_id: str, instance_id: Optional[str] = None, params: Optional[Dict] = None, - max_retry: int = 3 + max_retry: int = 3, ) -> dict: fallback = { - "state": [{"role": "system", "content": "create query failed, this is a empty task."}, - {"role": "user", "content": "create failed, this is a empty task,please close this task."}], + "state": [ + { + "role": "system", + "content": "create query failed, this is a empty task.", + }, + { + "role": "user", + "content": "create failed, this is a empty task,please close this task.", + }, + ], "reward": 0, "is_terminated": False, - "info": {"instance_id": instance_id or "", "task_id": task_id or ""}, + "info": { + "instance_id": instance_id or "", + "task_id": task_id or "", + }, } + def call(): # if env_type in map_env_type: env_type = map_env_type[env_type] r = self._make_request( @@ -159,13 +161,14 @@ def call(): params=params, ) return r["data"] + return retry_call( call, max_retry=max_retry, fail_return=fallback, - err_prefix=f"[create_instance]", + err_prefix="[create_instance]", instance_id=instance_id, - action_name="create_instance" + action_name="create_instance", ) def step( @@ -176,28 +179,35 @@ def step( max_retry: int = 3, ) -> dict: fallback = { - "state": [{"role": "assistant", "content": "Step failed (timeout or exception),please retry"}], + "state": [ + { + "role": "assistant", + "content": "Step failed (timeout or exception),please retry", + } + ], "reward": 0, "is_terminated": False, "info": {"instance_id": instance_id or "", "task_id": ""}, } + def call(): resp = self._make_request( endpoint="step", instance_id=instance_id, messages=action, - params=params + params=params, ) return resp["data"] + res = retry_call( call, max_retry=max_retry, fail_return=fallback, - err_prefix=f"[step]", + err_prefix="[step]", instance_id=instance_id, - action_name="step" + action_name="step", ) - res['state'] = res['state'][0] + res["state"] = res["state"][0] return res def evaluate( @@ -215,61 +225,26 @@ def call(): params=params, ) return resp.get("data", 0.0) + return retry_call( call, max_retry=max_retry, fail_return=0.0, - err_prefix=f"[evaluate]", + err_prefix="[evaluate]", instance_id=instance_id, - action_name="evaluate" + action_name="evaluate", ) def release_instance(self, instance_id: str, max_retry: int = 3) -> bool: def call(): resp = self._make_request(endpoint="release", instance_id=instance_id) return resp.get("success", False) + return retry_call( call, max_retry=max_retry, fail_return=False, - err_prefix=f"[release_instance]", + err_prefix="[release_instance]", instance_id=instance_id, - action_name="release_instance" + action_name="release_instance", ) - -# 使用示例 -def main(): - client = EnvClient() - env_type = "appworld" - - # 获取任务列表 - task_ids = client.get_env_profile(env_type) - print(f"Available tasks: {task_ids}") - - # 创建实例 - task_id = task_ids[0] if task_ids else None - if not task_id: - print("任务列表为空,无法创建实例!") - return - init_response = client.create_instance(env_type, task_id) - print("init state", init_response) - instance_id = init_response["info"]["instance_id"] - query = init_response.get("state", []) - print(f"Created instance {instance_id} with query: {query}") - - # 执行动作 - action = {"role": "assistant", "content": "print('hello appworld!!')"} - result = client.step(instance_id, action) - print(f"Step result: {result}") - - # 评估 - score = client.evaluate(instance_id) - print(f"Evaluation score: {score}") - - # 释放实例 - success = client.release_instance(instance_id) - print(f"Instance released: {success}") - - -if __name__ == "__main__": - main() diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py new file mode 100644 index 00000000..c5d9b30b --- /dev/null +++ b/ajet/utils/launch_utils.py @@ -0,0 +1,241 @@ +import logging +import os +import shutil +import subprocess +import sys +import time + +from beast_logger import print_dict +from loguru import logger + +from ajet.utils.config_utils import align_parameters +from ajet.utils.smart_daemon import LaunchCommandWhenAbsent + + +def set_loguru_default_color(): + logger.remove() + colorize = os.environ.get("LOGURU_COLORIZE", "YES").upper() not in ["NO", "0", "FALSE"] + logger.add(sys.stderr, colorize=colorize, enqueue=False) + if not colorize: + os.environ["RAY_COLOR_PREFIX"] = "0" + + logging.getLogger("vllm.entrypoints.openai.tool_parsers.hermes_tool_parser").setLevel(logging.CRITICAL) + return + + +def launch_logview(exp_name=None): + """ + Launch the log viewer service and open the web browser to view logs. + + Args: + exp_name: Optional experiment name. If not provided, "default_experiment" is used. + """ + companion = LaunchCommandWhenAbsent( + full_argument_list=[ + sys.executable, + "-m", + "web_display.start_web", + ], + dir="./", + tag="logview", + ) + companion.launch( + launch_wait_time=1800, + success_std_string="Uvicorn running on", + env_dict={}, + ) + try: + import webbrowser + + time.sleep(2.5) + webbrowser.open("http://127.0.0.1:8181/") + except Exception as e: + logger.error(f"Error opening web browser: {e}") + + +def start_ray_service(args, env, cluster=False): + """ + Start a Ray service with appropriate configuration. + + Args: + args: Command line arguments containing debug settings + """ + # Get the current Python interpreter directory + python_dir = os.path.dirname(sys.executable) + ray_path = os.path.join(python_dir, "ray") + if not cluster: + companion = LaunchCommandWhenAbsent( + full_argument_list=[f"{ray_path} start --head --block"], + dir="./", + tag="ray_service", + use_pty=True, + ) + launch_wait_time = 600 + success_std_string = "Ray runtime started" + else: + HOSTNAME = os.uname().nodename + MASTER_ADDR = os.getenv("MASTER_ADDR") + MASTER_PORT = os.getenv("MASTER_PORT") + if HOSTNAME == MASTER_ADDR: + companion = LaunchCommandWhenAbsent( + full_argument_list=[f"{ray_path} start --head --node-ip-address={MASTER_ADDR} --port={MASTER_PORT} --disable-usage-stats --block"], + dir="./", + tag="ray_service_head", + use_pty=True, + ) + launch_wait_time = 600 + success_std_string = "Ray runtime started" + else: + companion = LaunchCommandWhenAbsent( + full_argument_list=[f"{ray_path} start --address={MASTER_ADDR}:{MASTER_PORT} --disable-usage-stats --block"], + dir="./", + tag="ray_service_worker", + use_pty=True, + ) + launch_wait_time = 9999999999 + # success_std_string = "Connected to Ray cluster" + success_std_string = "Just wait here forever" + companion.launch( + launch_wait_time=launch_wait_time, + success_std_string=success_std_string, + env_dict=env, + ) + + +def verify_python_env(args, exp_config): + """ + Verify that the current Python environment matches the expected executable. + + Args: + args: Command line arguments containing the expected python_executable + """ + if exp_config["ajet"]["trainer_common"]["logger"] == "swanlab": + if os.environ.get("SWANLAB_API_KEY", "") == "": + cause = "SWANLAB_API_KEY is not set in the environment." + solution = "To use the swanlab logger, please set `SWANLAB_API_KEY`. Otherwise, set `ajet.trainer_common.logger=tensorboard`" + print_dict( + { + "Python Environment Check": "FAILED", + "Cause": cause, + "Solution": solution, + } + ) + time.sleep(5) + raise ImportError(cause + " " + solution) + + import verl + + if args.backbone == "trinity": + if any([v in verl.__version__ for v in ["0.5.0.post", "0.7.0.post"]]): + cause = "Python environment does not match current backbone 'trinity'." + solution = "Please `cd /path/to/project/AgentJet` and run `(uv) pip install -e .[trinity]` to install the correct environment." + print_dict( + { + "Python Environment Check": "FAILED", + "Cause": cause, + "Solution": solution, + } + ) + time.sleep(5) + raise ImportError(cause + " " + solution) + elif args.backbone == "verl": + if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post"]]): # you must install via `pip install -e .[verl]` to get every dependency right + cause = "Python environment does not match current backbone 'verl'." + solution = "Please `cd /path/to/project/AgentJet` and run `(uv) pip install -e .[verl]` to install the correct environment." + print_dict( + { + "Python Environment Check": "FAILED", + "Cause": cause, + "Solution": solution, + } + ) + time.sleep(5) + raise ImportError(cause + " " + solution) + + +def execute_training_process( + args, + backbone_target, + yaml_backup_dst, + exe_exp_base, + exe_yaml_path, + env, + exp_config, +): + """ + Execute the training process based on the specified backbone and configuration. + + Args: + args: Command line arguments + backbone_target: The Python module to execute + yaml_backup_dst: Path to the YAML configuration backup + exe_exp_base: Base path for experiment execution + exe_yaml_path: Path to the YAML configuration file + env: Environment variables dictionary + """ + + # Fixed config asset locations + TRINITY_BOOT_YAML = "ajet/default_config/trinity/trinity_launch.yaml" # THIS FILE IS READ ONLY, and ALWAYS FIXED + TRINITY_CONFIG_AUTO_CONVERSION = "ajet/default_config/trinity/config_auto_convertion_trinity.jsonc" + VERL_CONFIG_AUTO_CONVERSION = "ajet/default_config/verl/config_auto_convertion_verl.jsonc" + + os.makedirs("/tmp/ajet", exist_ok=True) + assert os.path.exists("/tmp/ajet"), "Temporary directory /tmp/ajet cannot be create." + + # let's begin the training process + if args.backbone == "trinity": + # replace boot yaml + redirect_trinity_boot_yaml = os.path.dirname(yaml_backup_dst) + "/trinity_launch.yaml" + shutil.copyfile(TRINITY_BOOT_YAML, redirect_trinity_boot_yaml) + align_parameters( + yaml_backup_dst, + redirect_trinity_boot_yaml, + TRINITY_CONFIG_AUTO_CONVERSION, + args.backbone, + ) + cmd = [ + sys.executable, + "-m", + backbone_target, + "run", + "--config", + redirect_trinity_boot_yaml, + ] + else: + align_parameters( + yaml_backup_dst, + yaml_backup_dst, + VERL_CONFIG_AUTO_CONVERSION, + args.backbone, + ) + cmd = [ + sys.executable, + "-m", + backbone_target, + "--config-path", + os.path.abspath(exe_exp_base), + "--config-name", + os.path.basename(exe_yaml_path), + ] + + if args.with_logview: + env.update({"BEST_LOGGER_WEB_SERVICE_URL": os.environ.get("BEST_LOGGER_WEB_SERVICE_URL", "http://127.0.0.1:8181/")}) + + try: + logger.info(f"Running command: {' '.join(cmd)}") + print_dict( + { + "Running Command": " ".join(cmd), + "Experiment Base": exe_exp_base, + "YAML Config": exe_yaml_path, + }, + header="Final Training Command & Directory", + ) + verify_python_env(args, exp_config) + subprocess.run(cmd, check=True, cwd=os.path.abspath("./"), env=env) + except subprocess.CalledProcessError as e: + logger.error(f"Error running subprocess: {e}") + sys.exit(1) + except Exception as e: + logger.error(f"Unexpected error: {e}") + sys.exit(1) diff --git a/ajet/utils/lowlevel_hook.py b/ajet/utils/lowlevel_hook.py new file mode 100644 index 00000000..7fa2e4b9 --- /dev/null +++ b/ajet/utils/lowlevel_hook.py @@ -0,0 +1,49 @@ +# def patch_task_creation(): +# pass + + +import inspect +import asyncio +from functools import wraps + + +def patch_task_creation(): + # Hook asyncio.create_task + original_create_task = asyncio.create_task + + @wraps(original_create_task) + def debug_create_task(coro, name=None, **kwargs): + if not name: + caller = inspect.stack()[1] + coro_name = getattr(coro, "__name__", str(coro)) + name = f"DEBUG_{coro_name}_at_{caller.filename.split('/')[-1]}:{caller.lineno}" + print(f"🎯 asyncio.create_task: {name}") + return original_create_task(coro, name=name, **kwargs) + + # Hook loop.create_task + original_loop_create_task = asyncio.AbstractEventLoop.create_task + + def debug_loop_create_task(self, coro, name=None, context=None): + if not name: + caller = inspect.stack()[1] + coro_name = getattr(coro, "__name__", str(coro)) + name = f"DEBUG_{coro_name}_at_{caller.filename.split('/')[-1]}:{caller.lineno}" + print(f"🎯 loop.create_task: {name}") + return original_loop_create_task(self, coro, name=name, context=context) + + # Hook Task.__init__ + original_task_init = asyncio.Task.__init__ + + def debug_task_init(self, coro, loop=None, name=None, context=None): + if not name: + caller = inspect.stack()[2] # 需要往上找更远的调用者 + coro_name = getattr(coro, "__name__", str(coro)) + name = f"DEBUG_{coro_name}_at_{caller.filename.split('/')[-1]}:{caller.lineno}" + print(f"🎯 Task.__init__: {name}") + return original_task_init(self, coro, loop=loop, name=name, context=context) + + asyncio.create_task = debug_create_task + asyncio.AbstractEventLoop.create_task = debug_loop_create_task + + +patch_task_creation() diff --git a/ajet/utils/magic_mock.py b/ajet/utils/magic_mock.py new file mode 100644 index 00000000..baeb6bca --- /dev/null +++ b/ajet/utils/magic_mock.py @@ -0,0 +1,21 @@ +class SpecialMagicMock(object): + def __init__(self, allowed_attributes=[]): + # Use __dict__ to avoid triggering __setattr__ + self.__dict__["allowed_attributes"] = allowed_attributes + self.__dict__["attr_store"] = {} + + def __getattr__(self, name): + if name in self.allowed_attributes: + return self.attr_store.get(name) + else: + raise ValueError(f"Attribute {name} is not allowed.") + + def __setattr__(self, name, value): + if name in self.allowed_attributes: + # Use __dict__ to avoid recursion + self.__dict__["attr_store"][name] = value + elif name in ("allowed_attributes", "attr_store"): + # Allow setting internal attributes directly + self.__dict__[name] = value + else: + raise ValueError(f"Attribute {name} is not allowed.") diff --git a/ajet/utils/message_utils.py b/ajet/utils/message_utils.py new file mode 100644 index 00000000..1c6abcad --- /dev/null +++ b/ajet/utils/message_utils.py @@ -0,0 +1,23 @@ +import copy +from typing import Dict, List + + +# apply chat_template to a message, and then convert back to message +def convert_tool_to_user_message(tool_message, tokenizer, format="qwen"): + assert format == "qwen" + + if tool_message["role"] == "user": + return tool_message + elif tool_message["role"] == "tool" and len(tool_message["tool_calls"]) > 0: + assert len(tool_message["tool_calls"]) == 1 + return { + "role": "user", + "content": str(tool_message["tool_calls"][0]["result"]), + } + + +def remove_fields(d: Dict, fields: List[str]) -> Dict: + d = copy.deepcopy(d) + for field in fields: + d.pop(field.strip(), None) + return d diff --git a/ajet/utils/metric_helper/__init__.py b/ajet/utils/metric_helper/__init__.py new file mode 100644 index 00000000..dc518016 --- /dev/null +++ b/ajet/utils/metric_helper/__init__.py @@ -0,0 +1,18 @@ +from ajet.utils.metric_helper.save_trajectory_as_json import save_trajectory_as_json +from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_trajectories +from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_trajectories + + +def save_trajectory_as_json_file(ctx_trackers, global_steps, config, prefix): + if config.ajet.trainer_common.save_trajectory_as_json_file: + save_trajectory_as_json(ctx_trackers, global_steps, prefix) + + +def update_metrics(context_tracker_arr, metrics: dict): + tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr) + reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr) + if tool_metrics: + metrics.update(tool_metrics) + if reward_metrics: + metrics.update(reward_metrics) + return diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py new file mode 100644 index 00000000..5a5b2303 --- /dev/null +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -0,0 +1,231 @@ +""" +FinWorld Reward Metrics Helper + +Provides standalone utility functions for reward_stats extraction and SwanLab metrics formatting. +Decouples finworld-specific logic from core code, reducing intrusion into native_compat_trainer. + +SwanLab metrics directory structure: +- rewards/ Top-level aggregated scores +- rewards/dimensions/ Raw scores (unweighted) +- rewards/contribution/ Weighted contributions +- judge_time/ Judge time consumption statistics +""" + +from typing import List, Dict, Any, Optional +import numpy as np + + +def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: + """ + Extract reward_stats from trajectories list. + + Args: + trajectories: List of trajectory objects containing workflow_metadata + + Returns: + List of reward_stats dictionaries + """ + reward_stats_list = [] + for traj in trajectories: + if hasattr(traj, "workflow_metadata") and traj.workflow_metadata: + if "reward_stats" in traj.workflow_metadata: + reward_stats_list.append(traj.workflow_metadata["reward_stats"]) + return reward_stats_list + + +def extract_reward_stats_from_cmts(cmts: List[Any]) -> tuple[List[Dict[str, Any]], Dict[str, int]]: + """ + Extract reward_stats from cmts list and return debug statistics. + + Args: + cmts: List of cmt objects containing workflow_metadata + + Returns: + Tuple of (reward_stats_list, debug_stats) + """ + reward_stats_list = [] + debug_stats = { + "total_cmts": len(cmts), + "has_workflow_metadata": 0, + "has_reward_stats": 0, + } + + for _cmt in cmts: + if hasattr(_cmt, "workflow_metadata") and _cmt.workflow_metadata: + debug_stats["has_workflow_metadata"] += 1 + if "reward_stats" in _cmt.workflow_metadata: + debug_stats["has_reward_stats"] += 1 + reward_stats_list.append(_cmt.workflow_metadata["reward_stats"]) + + return reward_stats_list, debug_stats + + +def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: + """ + Compute SwanLab metrics from reward_stats list. + + Supports two data sources: + 1. RM Gallery RewardStats fields (rm_raw, etc.) + 2. OpenJudge fields (openjudge_xxx_raw, openjudge_xxx_contribution, etc.) + + Args: + reward_stats_list: List of reward_stats dictionaries + prefix: Metric name prefix (e.g., "val/" for validation phase) + + Returns: + Formatted metrics dictionary ready for SwanLab reporting + """ + if not reward_stats_list: + return {} + + n = len(reward_stats_list) + metrics = {} + + # ========== Top-level Scores (General) ========== + final_reward_list = [rs.get("final_reward", 0.0) for rs in reward_stats_list] + fused_reward_list = [rs.get("fused_reward", 0.0) for rs in reward_stats_list] + penalty_list = [rs.get("penalty", 0.0) for rs in reward_stats_list] + step_reward_list = [rs.get("step_reward", 0.0) for rs in reward_stats_list] + + # Penalty statistics + non_zero_penalties = [p for p in penalty_list if p != 0.0] + + # Top-level metrics + metrics[f"{prefix}rewards/final_reward_mean"] = float(np.mean(final_reward_list)) + metrics[f"{prefix}rewards/fused_reward_mean"] = float(np.mean(fused_reward_list)) + metrics[f"{prefix}rewards/penalty_mean"] = float(np.mean(penalty_list)) + metrics[f"{prefix}rewards/step_reward_mean"] = float(np.mean(step_reward_list)) + metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties) + metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0 + + # ========== Detect OpenJudge Usage ========== + openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get("openjudge_enabled", False)) + + if openjudge_enabled_count > 0: + # ========== OpenJudge Metrics ========== + metrics[f"{prefix}rewards/openjudge_enabled_rate"] = openjudge_enabled_count / n * 100 + + # Dynamically extract OpenJudge grader fields + # Currently supported graders: report_resolution, trajectory_faithfulness, + # rubrics_performance, trajectory_comprehensive, information_gain, action_loop + openjudge_graders = [ + "report_resolution", + "trajectory_faithfulness", + "rubrics_performance", + "trajectory_comprehensive", + "information_gain", + "action_loop", + ] + + for grader_name in openjudge_graders: + raw_key = f"openjudge_{grader_name}_raw" + contrib_key = f"openjudge_{grader_name}_contribution" + + raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list] + contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list] + + # Only report when non-zero values exist + if any(v != 0.0 for v in raw_list): + metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list)) + if any(v != 0.0 for v in contrib_list): + metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list)) + + # OpenJudge time consumption statistics + grading_time_list = [rs.get("grading_time", 0.0) for rs in reward_stats_list] + if any(v != 0.0 for v in grading_time_list): + metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list)) + metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list)) + + # ========== RM Gallery Metrics ========== + + # RM Gallery + rm_raw_list = [rs.get("rm_raw", 0.0) for rs in reward_stats_list] + rm_contribution_list = [rs.get("rm_contribution", 0.0) for rs in reward_stats_list] + + # RefJudge + ref_final_raw_list = [rs.get("ref_final_raw", 0.0) for rs in reward_stats_list] + ref_citation_raw_list = [rs.get("ref_citation_raw", 0.0) for rs in reward_stats_list] + ref_grounding_raw_list = [rs.get("ref_grounding_raw", 0.0) for rs in reward_stats_list] + ref_contribution_list = [rs.get("ref_contribution", 0.0) for rs in reward_stats_list] + + # StructureJudge + structure_raw_list = [rs.get("structure_raw", 0.0) for rs in reward_stats_list] + structure_contribution_list = [rs.get("structure_contribution", 0.0) for rs in reward_stats_list] + + # dimensions/ raw scores + metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list)) + metrics[f"{prefix}rewards/dimensions/ref_final_raw_mean"] = float(np.mean(ref_final_raw_list)) + metrics[f"{prefix}rewards/dimensions/ref_citation_raw_mean"] = float(np.mean(ref_citation_raw_list)) + metrics[f"{prefix}rewards/dimensions/ref_grounding_raw_mean"] = float(np.mean(ref_grounding_raw_list)) + metrics[f"{prefix}rewards/dimensions/structure_raw_mean"] = float(np.mean(structure_raw_list)) + + # contribution/ weighted contributions + metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list)) + metrics[f"{prefix}rewards/contribution/ref_contribution_mean"] = float(np.mean(ref_contribution_list)) + metrics[f"{prefix}rewards/contribution/structure_contribution_mean"] = float(np.mean(structure_contribution_list)) + + # Enabled state statistics + ref_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get("ref_judge_enabled", False)) + if ref_judge_enabled_count > 0: + metrics[f"{prefix}rewards/ref_judge_enabled_rate"] = ref_judge_enabled_count / n * 100 + + structure_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get("structure_judge_enabled", False)) + if structure_judge_enabled_count > 0: + metrics[f"{prefix}rewards/structure_judge_enabled_rate"] = structure_judge_enabled_count / n * 100 + + # Time consumption statistics + rm_time_list = [rs.get("rm_time", 0.0) for rs in reward_stats_list] + refstruc_time_list = [rs.get("refstruc_time", 0.0) for rs in reward_stats_list] + + metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list)) + metrics[f"{prefix}judge_time/refstruc_time_mean"] = float(np.mean(refstruc_time_list)) + + if rm_time_list: + metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list)) + if refstruc_time_list: + metrics[f"{prefix}judge_time/refstruc_time_max"] = float(np.max(refstruc_time_list)) + + # ========== General Time Consumption Statistics ========== + judge_total_time_list = [rs.get("judge_total_time", 0.0) for rs in reward_stats_list] + if any(v != 0.0 for v in judge_total_time_list): + metrics[f"{prefix}judge_time/judge_total_time_mean"] = float(np.mean(judge_total_time_list)) + metrics[f"{prefix}judge_time/judge_total_time_max"] = float(np.max(judge_total_time_list)) + + return metrics + + +def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: + """ + Training phase: Extract reward_stats from trajectories and compute metrics. + + Args: + trajectories: List of trajectory objects + + Returns: + Formatted metrics dictionary + """ + reward_stats_list = extract_reward_stats_from_trajectories(trajectories) + return compute_reward_metrics(reward_stats_list, prefix="train_") + + +def compute_reward_metrics_from_cmts(cmts: List[Any], print_debug: bool = True) -> Dict[str, float]: + """ + Validation phase: Extract reward_stats from cmts and compute metrics. + + Args: + cmts: List of cmt objects + print_debug: Whether to print debug information + + Returns: + Formatted metrics dictionary (with "val_reward/" prefix) + """ + reward_stats_list, debug_stats = extract_reward_stats_from_cmts(cmts) + + if print_debug: + print(f"\n[DEBUG eval_dataset()] reward_stats statistics:") + print(f" - Total cmts count: {debug_stats['total_cmts']}") + print(f" - Has workflow_metadata: {debug_stats['has_workflow_metadata']}") + print(f" - Has reward_stats: {debug_stats['has_reward_stats']}") + print(f" - Extracted samples count: {len(reward_stats_list)}") + + return compute_reward_metrics(reward_stats_list, prefix="val_") diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py new file mode 100644 index 00000000..c9bb16b0 --- /dev/null +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -0,0 +1,46 @@ +import os +import json +from ajet.utils.msg_converter import convert_grouped_steps_to_openai_format + + +def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): + """ + Save ctx_trackers to JSON files for either training or evaluation. + + Args: + ctx_trackers (list): List of context trackers containing trajectory data. + global_steps (int): The global step count to organize saved files. + prefix (str): Directory prefix indicating the type of trajectory ("train" or "eval"). + """ + for ctx_tracker in ctx_trackers: + # Determine task tag based on reward + reward = ctx_tracker.reward_structure.raw_reward + if reward >= 1: + ctx_tracker.tag = "success" + elif reward == 0: + ctx_tracker.tag = "failure" + else: + ctx_tracker.tag = "half_success" + + formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.timeline_cache) + + # Prepare trajectory data + traj_data = {"task_id": ctx_tracker.task_id, "task_tag": ctx_tracker.tag, "reward_structure": ctx_tracker.reward_structure.model_dump(), "traj": formatted_traj} + + # Extract reward_stats from workflow_metadata if available + if hasattr(ctx_tracker, "workflow_metadata") and ctx_tracker.workflow_metadata: + if "reward_stats" in ctx_tracker.workflow_metadata: + traj_data["reward_structure"]["reward_stats"] = ctx_tracker.workflow_metadata["reward_stats"] + + # Define save directory and file path + traj_save_dir = os.path.join(os.environ.get("BEST_LOGGER_PATH", "launcher_record"), "ctx_trackers", prefix, f"step_{global_steps}") + os.makedirs(traj_save_dir, exist_ok=True) + traj_file_path = os.path.join(traj_save_dir, f"{ctx_tracker.task_id}.json") + + # Save trajectory data to JSON file + with open(traj_file_path, "w", encoding="utf-8") as f: + json.dump(traj_data, f, ensure_ascii=False, indent=2) + + # Print confirmation for evaluation trajectories + if prefix != "train": + print(f"Saved trajectory to {traj_file_path}") diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py new file mode 100644 index 00000000..b58aad03 --- /dev/null +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -0,0 +1,169 @@ +""" +FinWorld Tool Metrics Helper + +Specialized module for extracting tool-related statistics and formatting SwanLab reports. +Extracts data from workflow_metadata['tool_stats']. + +SwanLab metrics directory structure: +- tool_stats/ Overall statistics (success rate, cache hit rate, etc.) +- tool_time/ Time consumption statistics by tool +- tool_cache/ Cache hit rate by tool +- tool_error/ Error rate by tool +""" + +from typing import List, Dict, Any +import numpy as np + + +def extract_tool_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: + """ + Extract tool_stats from trajectories list. + + Args: + trajectories: List of trajectory objects containing workflow_metadata + + Returns: + List of tool_stats dictionaries + """ + tool_stats_list = [] + for traj in trajectories: + if hasattr(traj, "workflow_metadata") and traj.workflow_metadata: + if "tool_stats" in traj.workflow_metadata: + tool_stats_list.append(traj.workflow_metadata["tool_stats"]) + return tool_stats_list + + +def extract_tool_stats_from_cmts(cmts: List[Any]) -> List[Dict[str, Any]]: + """ + Extract tool_stats from cmts list. + + Args: + cmts: List of cmt objects containing workflow_metadata + + Returns: + List of tool_stats dictionaries + """ + tool_stats_list = [] + for traj in trajs: + if hasattr(traj, "workflow_metadata") and traj.workflow_metadata: + if "tool_stats" in traj.workflow_metadata: + tool_stats_list.append(traj.workflow_metadata["tool_stats"]) + return tool_stats_list + + +def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: + """ + Compute SwanLab metrics from tool_stats list. + + Args: + tool_stats_list: List of tool_stats dictionaries + prefix: Metric name prefix (e.g., "val/" for validation phase) + + Returns: + Formatted metrics dictionary ready for SwanLab reporting + """ + if not tool_stats_list: + return {} + + metrics = {} + + # ========== 1. Overall Statistics ========== + total_calls_list = [stats.get("total_calls", 0) for stats in tool_stats_list] + success_calls_list = [stats.get("success_calls", 0) for stats in tool_stats_list] + error_calls_list = [stats.get("total_errors", 0) for stats in tool_stats_list] + cache_hits_list = [stats.get("cache_hits", 0) for stats in tool_stats_list] + cache_misses_list = [stats.get("cache_misses", 0) for stats in tool_stats_list] + + # Calculate overall success rate + total_calls_sum = sum(total_calls_list) + success_calls_sum = sum(success_calls_list) + tool_success_rate = (success_calls_sum / total_calls_sum * 100) if total_calls_sum > 0 else 0.0 + + # Calculate overall cache hit rate + cache_total = sum(cache_hits_list) + sum(cache_misses_list) + cache_hit_rate = (sum(cache_hits_list) / cache_total * 100) if cache_total > 0 else 0.0 + + metrics.update( + { + f"{prefix}tool_stats/tool_success_rate": tool_success_rate, + f"{prefix}tool_stats/tool_total_calls": float(np.mean(total_calls_list)), + f"{prefix}tool_stats/tool_success_calls": float(np.mean(success_calls_list)), + f"{prefix}tool_stats/tool_error_calls": float(np.mean(error_calls_list)), + f"{prefix}tool_stats/tool_cache_hit_rate": cache_hit_rate, + f"{prefix}tool_stats/tool_cache_hits": float(np.mean(cache_hits_list)), + f"{prefix}tool_stats/tool_cache_misses": float(np.mean(cache_misses_list)), + } + ) + + # ========== 2. Time Consumption Statistics by Tool ========== + tool_time_by_name = {} + for stats in tool_stats_list: + tool_time_dict = stats.get("tool_time", {}) + for tool_name, time_list in tool_time_dict.items(): + if tool_name not in tool_time_by_name: + tool_time_by_name[tool_name] = [] + if isinstance(time_list, list): + tool_time_by_name[tool_name].extend(time_list) + + for tool_name, time_list in tool_time_by_name.items(): + if time_list: + metrics[f"{prefix}tool_time/{tool_name}/mean"] = float(np.mean(time_list)) + metrics[f"{prefix}tool_time/{tool_name}/max"] = float(np.max(time_list)) + metrics[f"{prefix}tool_time/{tool_name}/count"] = len(time_list) + + # ========== 3. Cache Hit Rate by Tool ========== + tool_cache_by_name = {} + for stats in tool_stats_list: + tool_cache_stats = stats.get("tool_cache_stats", {}) + for tool_name, cache_info in tool_cache_stats.items(): + if tool_name not in tool_cache_by_name: + tool_cache_by_name[tool_name] = {"hits": 0, "misses": 0} + tool_cache_by_name[tool_name]["hits"] += cache_info.get("hits", 0) + tool_cache_by_name[tool_name]["misses"] += cache_info.get("misses", 0) + + for tool_name, cache_info in tool_cache_by_name.items(): + hits = cache_info["hits"] + misses = cache_info["misses"] + total = hits + misses + if total > 0: + hit_rate = hits / total * 100 + metrics[f"{prefix}tool_cache/{tool_name}/hit_rate"] = round(hit_rate, 2) + metrics[f"{prefix}tool_cache/{tool_name}/hits"] = hits + metrics[f"{prefix}tool_cache/{tool_name}/misses"] = misses + + # ========== 4. Error Rate by Tool ========== + tool_error_by_name = {} + for stats in tool_stats_list: + tool_error_stats = stats.get("tool_error_stats", {}) + for tool_name, error_info in tool_error_stats.items(): + if tool_name not in tool_error_by_name: + tool_error_by_name[tool_name] = {"calls": 0, "errors": 0} + tool_error_by_name[tool_name]["calls"] += error_info.get("calls", 0) + tool_error_by_name[tool_name]["errors"] += error_info.get("errors", 0) + + for tool_name, error_info in tool_error_by_name.items(): + calls = error_info["calls"] + errors = error_info["errors"] + if calls > 0: + error_rate = errors / calls * 100 + metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) + metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls + metrics[f"{prefix}tool_error/{tool_name}/errors"] = errors + + return metrics + + +def compute_tool_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: + """ + Training phase: Extract tool_stats from trajectories and compute metrics. + """ + tool_stats_list = extract_tool_stats_from_trajectories(trajectories) + return compute_tool_metrics(tool_stats_list, prefix="train_") + + +def compute_tool_metrics_from_cmts(cmts: List[Any]) -> Dict[str, float]: + """ + Validation phase: Extract tool_stats from cmts and compute metrics. + """ + tool_stats_list = extract_tool_stats_from_cmts(cmts) + return compute_tool_metrics(tool_stats_list, prefix="val_") diff --git a/ajet/utils/msg_converter.py b/ajet/utils/msg_converter.py new file mode 100644 index 00000000..de50d5fc --- /dev/null +++ b/ajet/utils/msg_converter.py @@ -0,0 +1,94 @@ +""" +Message format conversion utilities + +Provides bidirectional conversion between OpenAI format and AgentScope format. +Unified for both train and val phases. + +## OpenAI format examples: +- Assistant with tool_calls: + {"role": "assistant", "content": "...", "tool_calls": [{"id": "call_xxx", "type": "function", "function": {"name": "...", "arguments": "..."}}]} +- Tool result: + {"role": "tool", "content": "...", "tool_call_id": "call_xxx"} +- Normal message: + {"role": "user/assistant/system", "content": "..."} + +## AgentScope format examples: +- Assistant with tool_calls: + {"role": "assistant", "content": [{"type": "text", "text": "..."}, {"type": "tool_use", "id": "call_xxx", "name": "...", "input": {...}}]} +- Tool result: + {"role": "user", "content": [{"type": "tool_result", "id": "call_xxx", "output": "..."}]} +- Normal message: + {"role": "user/assistant/system", "content": "..."} +""" + +import json +from typing import List, Dict, Any, Union + + +# ============================================================================= +# ExtendedMessage -> OpenAI conversion (backward compatible functions) +# ============================================================================= + + +def convert_ext_msg_to_openai_format(ext_msg: Any) -> Dict[str, Any]: + """ + Convert a single ExtendedMessage or dict to OpenAI format message. + + Args: + ext_msg: ExtendedMessage object or dict + + Returns: + Message dict in OpenAI format + """ + + # Helper function: get attribute value + def get_attr(obj, attr_name, default=None): + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + elif isinstance(obj, dict): + return obj.get(attr_name, default) + return default + + # Check if there are tool_calls (assistant initiates tool call) + tool_calls = get_attr(ext_msg, "tool_calls") + has_tool_calls = bool(tool_calls) + + # Check if there's tool_call_id (tool return result) + tool_call_id = get_attr(ext_msg, "tool_call_id") + has_tool_call_id = bool(tool_call_id) + + # Get basic attributes + role = get_attr(ext_msg, "role", "user") + content = get_attr(ext_msg, "content", "") + + if has_tool_calls: + # Assistant message contains tool_calls -> keep OpenAI format + msg_dict = {"role": "assistant", "content": content if content else "", "tool_calls": tool_calls} + elif has_tool_call_id: + # Tool return result -> use OpenAI format (role: "tool") + msg_dict = {"role": "tool", "content": content if content else "", "tool_call_id": tool_call_id} + else: + # Normal message, keep original format + msg_dict = {"role": role, "content": content if content else ""} + + return msg_dict + + +def convert_grouped_steps_to_openai_format(timelines: List[List[Any]]) -> List[List[Dict[str, Any]]]: + """ + Convert timelines (multi-turn conversation steps) to OpenAI format. + + Args: + timelines: List of List of ExtendedMessage or dict + + Returns: + Trajectory data in OpenAI format (List of List of dict) + """ + formatted_traj = [] + for context in timelines: + step_msgs = [] + for ext_msg in context: + msg_dict = convert_ext_msg_to_openai_format(ext_msg) + step_msgs.append(msg_dict) + formatted_traj.append(step_msgs) + return formatted_traj diff --git a/ajet/utils/networking.py b/ajet/utils/networking.py new file mode 100644 index 00000000..38bac512 --- /dev/null +++ b/ajet/utils/networking.py @@ -0,0 +1,34 @@ +import socket + + +def find_free_port() -> int: + """Find a free port on the system.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +def get_host_ip(interface=None): + """ + get the host machine's IP address. + """ + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + if interface: + try: + import fcntl + import struct + + ip = socket.inet_ntoa(fcntl.ioctl(s.fileno(), 0x8915, struct.pack("256s", interface[:15].encode()))[20:24]) # SIOCGIFADDR + return ip + except (ImportError, IOError): + pass + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + s.close() + return ip + + except Exception: + return "127.0.0.1" diff --git a/ajet/utils/process_dataset.py b/ajet/utils/process_dataset.py new file mode 100644 index 00000000..917b5053 --- /dev/null +++ b/ajet/utils/process_dataset.py @@ -0,0 +1,58 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Union + +import torch +from omegaconf import DictConfig +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import RandomSampler, SequentialSampler +from verl.experimental.dataset.sampler import AbstractSampler +from verl.utils.import_utils import load_extern_type + + +def create_rl_sampler(data_config: DictConfig, dataset: TorchDataset) -> Union[RandomSampler, SequentialSampler, AbstractSampler]: + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, "If using curriculum, num_workers must be 0 to prevent data caching. " "If the dataloader caches data before the batch is done the " "curriculum sampler won't have the opportunity to reorder it. " + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", int(time.time()))) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler diff --git a/ajet/utils/pty.py b/ajet/utils/pty.py new file mode 100644 index 00000000..2affd1fc --- /dev/null +++ b/ajet/utils/pty.py @@ -0,0 +1,145 @@ +import base64 +import os +import pty + + +def run_command_with_pty(cmd, working_dir, env_dict): + """ + Run a command in a pseudo-terminal (PTY) and stream output to stdout. + + Args: + cmd (list): Command to run (e.g., ["ls", "-l"]). + working_dir (str): Working directory. + env_dict (dict): Environment variables dictionary. + """ + # Save original environment and directory + original_env = os.environ.copy() + original_dir = os.getcwd() + + try: + # Change to the target working directory + os.chdir(working_dir) + + # Update environment variables + for key, value in env_dict.items(): + os.environ[key] = value + + # # Open a log file in append mode (optional) + # with open(log_file, 'a') as log_f: + + # Define master device read callback + def master_read(fd): + try: + # Read data from PTY master + data = os.read(fd, 1024) + except OSError: + return b"" + + if data: + # Write data to log file + # log_f.write(data.decode()) + # log_f.flush() + # Also print to stdout (optional) + # Use errors='replace' to handle incomplete UTF-8 sequences + print(data.decode(errors="replace"), end="") + return data + + # Define stdin read callback + def stdin_read(fd): + # Return empty bytes if no stdin input is needed + return b"" + + # Spawn a PTY and run the command + pty.spawn(cmd, master_read, stdin_read) + + finally: + # Restore original working directory + os.chdir(original_dir) + + # Restore original environment variables + os.environ.clear() + os.environ.update(original_env) + + +# Convert string to Base64 +def string_to_base64(s): + # First, encode the string to bytes + s_bytes = s.encode("utf-8") + # Convert bytes to base64 + base64_bytes = base64.b64encode(s_bytes) + # Convert base64 bytes back to string + base64_string = base64_bytes.decode("utf-8") + return base64_string + + +# Convert Base64 back to string +def base64_to_string(b): + # Convert base64 string to bytes + base64_bytes = b.encode("utf-8") + # Decode base64 bytes + message_bytes = base64.b64decode(base64_bytes) + # Convert bytes back to string + message = message_bytes.decode("utf-8") + return message + + +def pty_wrapper( + cmd: list[str], + dir: str, + env_dict: dict[str, str] = {}, +): + run_command_with_pty(cmd, working_dir=dir, env_dict=env_dict) + + +def pty_wrapper_final(human_cmd, dir, env_dict): + print("[pty]: ", human_cmd) + pty_wrapper(["/bin/bash", "-c", human_cmd], dir, env_dict) + + +def pty_launch(service_name: str, success_std_string="Starting server on"): + from ajet.utils.smart_daemon import LaunchCommandWhenAbsent + + service_path = os.environ.get(f"{service_name.upper()}_PATH") + service_script = os.environ.get(f"{service_name.upper()}_SCRIPT") + if service_path is None or service_script is None: + raise ValueError(f"Environment variables for {service_name} not properly set.") + companion = LaunchCommandWhenAbsent( + full_argument_list=[service_script], + dir=service_path, + tag=f"{service_name}_service", + use_pty=True, + ) + companion.launch( + launch_wait_time=3600, + success_std_string=success_std_string, + ) + + +if __name__ == "__main__": + import argparse + import json + + parser = argparse.ArgumentParser(description="Run a shell command in a PTY with logging and custom env.") + parser.add_argument("--human-cmd", type=str, help="Shell command to run (as a string)") + parser.add_argument("--dir", type=str, default=".", help="Working directory") + parser.add_argument( + "--env", + type=str, + default="{}", + help='Environment variables as JSON string, e.g. \'{"KEY":"VAL"}\'', + ) + + args = parser.parse_args() + + try: + env_dict = json.loads(args.env) + if not isinstance(env_dict, dict): + raise ValueError + except Exception: + print( + '--env must be a valid JSON object string, e.g. \'{"KEY":"VAL"}\'. But get:', + args.env, + ) + exit(1) + + pty_wrapper_final(base64_to_string(args.human_cmd), args.dir, env_dict) diff --git a/ajet/utils/retry.py b/ajet/utils/retry.py new file mode 100644 index 00000000..4fd84e0c --- /dev/null +++ b/ajet/utils/retry.py @@ -0,0 +1,50 @@ +import time +from functools import wraps +from typing import Any, Callable, Optional, TypeVar + +from loguru import logger + +from ajet.utils.testing_utils import TestFailException, TestSuccessException + +T = TypeVar("T") + + +def retry_with_backoff( + max_retry: int = 3, + backoff_fn: Optional[Callable[[int], float]] = None, + max_retry_attr: Optional[str] = None, +) -> Callable[[Callable[..., T]], Callable[..., T]]: + """Retry decorator with exponential backoff and structured logging.""" + + def decorator(func: Callable[..., T]) -> Callable[..., T]: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> T: + target_max_retry = max_retry + if max_retry_attr and args: + candidate = getattr(args[0], max_retry_attr, None) + if isinstance(candidate, int) and candidate > 0: + target_max_retry = candidate + if target_max_retry < 1: + target_max_retry = 1 + + for attempt in range(target_max_retry): + try: + return func(*args, **kwargs) + except TestSuccessException as exc: # noqa: BLE001 + raise exc + except TestFailException as exc: # noqa: BLE001 + raise exc + except Exception as exc: # noqa: BLE001 + if attempt < target_max_retry - 1: + logger.bind(exception=True).exception(f"{func.__name__} error: {exc.args}, retrying {attempt + 1}/{target_max_retry}") + sleep_seconds = backoff_fn(attempt) if backoff_fn else 2**attempt + time.sleep(sleep_seconds) + else: + logger.bind(exception=True).exception(f"{func.__name__} failed after {target_max_retry} retries: {exc.args}") + raise + + raise RuntimeError("retry_with_backoff exhausted attempts") + + return wrapper + + return decorator diff --git a/ajet/utils/robust_dashscope.py b/ajet/utils/robust_dashscope.py new file mode 100644 index 00000000..3cfc90f6 --- /dev/null +++ b/ajet/utils/robust_dashscope.py @@ -0,0 +1,103 @@ +import os +import random +import time +from textwrap import dedent + +from agentscope.model import DashScopeChatModel +from loguru import logger + + +class RobustDashScopeChatModel(DashScopeChatModel): + """ + A robust version of DashScopeChatModel that includes retry logic and multiple API key handling. + This class extends the DashScopeChatModel from agentscope and adds: + 1. Support for multiple API keys separated by '|' in environment variables + 2. Automatic retry logic with backup API keys + 3. Error handling with appropriate logging + """ + + def __init__(self, model_name="qwen3-max", stream=False, max_try=4, **kwargs): + # Check for environment variables + self._check_env_variables() + + # Parse API keys from environment variables + self.regular_key_list = os.environ.get("DASHSCOPE_API_KEY", "").split("|") + self.backup_key_list = os.environ.get("DASHSCOPE_API_KEY_BACKUP", "").split("|") if os.environ.get("DASHSCOPE_API_KEY_BACKUP") else [] + + api_key = random.choice(self.regular_key_list) + + # Store retry parameters + self.max_try = max_try + + # Initialize the parent class + super().__init__(api_key=api_key, model_name=model_name, stream=stream, **kwargs) + + def _check_env_variables(self): + """Check if required environment variables are set.""" + if os.environ.get("DASHSCOPE_API_KEY") is None: + raise RuntimeError( + dedent( + """ + Please set the DASHSCOPE_API_KEY environment variable. + You can get the API keys from https://www.dashscope.com/. + Example: + export DASHSCOPE_API_KEY='sk-xxxxxx|sk-yyyyyy' + export DASHSCOPE_API_KEY_BACKUP='sk-zzzzzz' (optional) + """ + ) + ) + + async def __call__( + self, + messages, + tools=None, + tool_choice=None, + structured_model=None, + **kwargs, + ): + """ + Override the __call__ method to add retry logic and API key rotation. + + Args: + messages: The messages to send to the model + tools: Optional list of tools + tool_choice: Optional tool choice + structured_model: Optional structured model + **kwargs: Additional arguments to pass to the API + + Returns: + The response from the model + + Raises: + RuntimeError: If all retry attempts fail + """ + for n_try in range(self.max_try): + try: + # Select API key based on retry attempt + if n_try < self.max_try // 2: + # For first half of attempts, use regular keys + self.api_key = random.choice(self.regular_key_list) + elif n_try == self.max_try // 2 and self.backup_key_list: + # At middle attempt, try backup key if available + self.api_key = random.choice(self.backup_key_list) + else: + # For remaining attempts, use any available key + self.api_key = random.choice(self.regular_key_list + self.backup_key_list) + + # Call the parent class's __call__ method + response = await super().__call__( + messages=messages, + tools=tools, + tool_choice=tool_choice, + structured_model=structured_model, + **kwargs, + ) + return response + + except Exception as e: + logger.bind(exception=True).exception(f"Error calling DashScope API: {e}") + time.sleep(5) # Wait before retrying + print(f"Error calling DashScope API: {e}, retrying ({n_try + 1}/{self.max_try})...") + + # If all attempts fail + raise RuntimeError(f"Failed to get response from DashScope API after {self.max_try} attempts") diff --git a/ajet/utils/sample.py b/ajet/utils/sample.py new file mode 100644 index 00000000..0e2bbcb9 --- /dev/null +++ b/ajet/utils/sample.py @@ -0,0 +1,33 @@ +def get_sample_params(mode, config): + """ + Generate sampling parameters for text generation based on mode and config. + Args: + mode (str): The mode of operation, e.g., 'validate'. + config: Configuration object containing rollout parameters. + Returns: + dict: Sampling parameters for the model. + """ + response_length_eps = 16 # Reserve a few tokens for later handling of special tokens like lm_start. + if config.ajet.rollout.name == "vllm": + # VLLM uses max_tokens instead of max_new_tokens + sampling_params = dict( + n=1, + max_tokens=config.ajet.rollout.max_response_length_in_one_turn - response_length_eps, + min_tokens=1, # Must output at least 1 token. + temperature=config.ajet.rollout.temperature, + top_p=config.ajet.rollout.top_p, + logprobs=1, + ) + else: + sampling_params = dict( + n=1, + max_new_tokens=config.ajet.rollout.max_response_length_in_one_turn, + temperature=config.ajet.rollout.temperature, + top_p=config.ajet.rollout.top_p, + ) + + if mode == "validate": + sampling_params["temperature"] = config.ajet.rollout.val_kwargs.temperature + sampling_params["top_k"] = config.ajet.rollout.val_kwargs.top_k + sampling_params["top_p"] = config.ajet.rollout.val_kwargs.top_p + return sampling_params diff --git a/ajet/utils/sington.py b/ajet/utils/sington.py new file mode 100644 index 00000000..b46095c5 --- /dev/null +++ b/ajet/utils/sington.py @@ -0,0 +1,9 @@ +def singleton(cls): + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance diff --git a/astune/utils/smart_daemon.py b/ajet/utils/smart_daemon.py similarity index 50% rename from astune/utils/smart_daemon.py rename to ajet/utils/smart_daemon.py index 36b91bf7..b3702f34 100644 --- a/astune/utils/smart_daemon.py +++ b/ajet/utils/smart_daemon.py @@ -1,14 +1,35 @@ -import os -import sys -import psutil -import subprocess +import base64 import hashlib -import time import json import logging -from loguru import logger +import os +import subprocess +import sys +import time from pathlib import Path -from typing import Optional, Tuple, List +from typing import List, Optional, Tuple + +import psutil +from beast_logger import print_dict +from loguru import logger + + +def string_to_base64(s): + """ + Convert a string to its base64 encoded representation. + Args: + s (str): Input string. + Returns: + str: Base64 encoded string. + """ + # First, encode the string to bytes + s_bytes = s.encode("utf-8") + # Then convert bytes to base64 + base64_bytes = base64.b64encode(s_bytes) + # Finally, convert base64 bytes back to string + base64_string = base64_bytes.decode("utf-8") + return base64_string + class LaunchWhenAbsent: """ @@ -16,7 +37,15 @@ class LaunchWhenAbsent: If the script is already running, it will skip launching unless force_restart is True. """ - def __init__(self, script_path: str, argument_list: List[str] = None, exe: str = None, dir = None, tag='', use_pty=False): + def __init__( + self, + script_path: str, + argument_list: List[str] = None, + exe: str = None, + dir=None, + tag="", + use_pty=False, + ): """ Initialize with the path to the Python script to be launched. @@ -41,16 +70,16 @@ def __init__(self, script_path: str, argument_list: List[str] = None, exe: str = full_argument_list = [self.script_path] + self.argument_list hash_items = full_argument_list + [str(self.dir), str(exe)] - self.script_hash = hashlib.md5(''.join(hash_items).encode()).hexdigest()[:8] + self.script_hash = hashlib.md5("".join(hash_items).encode()).hexdigest()[:8] # Prepare command with hash ID marker if self.use_pty: assert len(full_argument_list) == 1 - self.cmd = [self.exe + " " + full_argument_list[0]] + self.cmd = [self.exe + " " + full_argument_list[0]] else: - self.cmd = ['nohup'] + [self.exe] + full_argument_list + self.cmd = ["nohup"] + [self.exe] + full_argument_list - log_dir = Path("launcher_record/companion_logs/companion") + log_dir = Path("saved_experiments/companion_logs/companion") log_dir.mkdir(parents=True, exist_ok=True) hostname = os.uname().nodename if tag: @@ -61,8 +90,9 @@ def __init__(self, script_path: str, argument_list: List[str] = None, exe: str = self.logger_file = log_dir / f"{base_log_name}.log" self.pgid = None - - def _is_script_running(self) -> Tuple[bool, Optional[psutil.Process], Optional[int]]: + def _is_script_running( + self, + ) -> Tuple[bool, Optional[psutil.Process], Optional[int]]: """ Check if the script is already running by looking for its unique hash ID in process command lines. @@ -76,7 +106,7 @@ def _is_script_running(self) -> Tuple[bool, Optional[psutil.Process], Optional[i if not self.pgid_file.exists(): return False, None, None else: - with open(self.pgid_file, 'r') as f_pgid: + with open(self.pgid_file, "r") as f_pgid: pgid = int(f_pgid.read().strip()) # Check if the process group ID is still running, if true, psutil is_running, proc = self.is_pgid_running(pgid) @@ -89,9 +119,22 @@ def _is_script_running(self) -> Tuple[bool, Optional[psutil.Process], Optional[i return False, None, None def is_pgid_running(self, pgid): - for proc in psutil.process_iter(['pid']): + """ + Check if a process group ID is running (ignoring zombies). + Args: + pgid (int): Process group ID. + Returns: + Tuple[bool, Optional[psutil.Process]]: (is_running, process_if_found) + """ + # Treat zombie processes as not running to avoid false positives. + for proc in psutil.process_iter(["pid", "status"]): try: + if proc.info.get("status") == psutil.STATUS_ZOMBIE: + continue if os.getpgid(proc.pid) == pgid: + # Double-check status to avoid races where the cached info is missing. + if proc.status() == psutil.STATUS_ZOMBIE: + continue return True, proc except (psutil.NoSuchProcess, ProcessLookupError): continue @@ -133,121 +176,193 @@ def _kill_existing_process_group(self, pgid: int): print(f"Cleaned up PGID file: {self.pgid_file}") def shutdown(self): + """ + Shutdown the process group if running. + """ if self.pgid: self._kill_existing_process_group(self.pgid) - def launch(self, force_restart: bool = False, launch_wait_time: int = 30, success_std_string: str = None, env_dict = {}): + def kill_self(self): + """Force terminate this launcher instance if it's running.""" + is_running, _, pgid = self._is_script_running() + if not is_running or pgid is None: + logger.info("No running process group found for this launcher") + return False + self.pgid = pgid + self._kill_existing_process_group(pgid) + return True + + def launch( + self, + force_restart: bool = False, + launch_wait_time: int = 30, + success_std_string: str | None | List[str] = None, + env_dict={}, + ) -> str: """ Launch the script if it's not running, or restart it if force_restart is True. Args: force_restart (bool): If True, kill existing process and restart launch_wait_time (int): Maximum time to wait for process launch in seconds - success_std_string (str): String to look for in stdout to confirm successful launch + success_std_string (str, List[str]): String to look for in stdout to confirm successful launch + env_dict (dict): Environment variables for the process + Returns: + str: Content that hit the success string, if any. """ is_running, existing_process, pgid = self._is_script_running() self.pgid = pgid + # convert to list to simplify later checks + if isinstance(success_std_string, str): + success_std_string_arr = [success_std_string] + else: + success_std_string_arr = success_std_string + hit_success_string_content = "" + if is_running: if force_restart: - logger.warning(f"Force restarting") + assert pgid is not None + logger.warning("Force restarting") self._kill_existing_process_group(pgid) else: logger.success(f"Script is already running, skipping launch. pgid: {pgid}. Command [{' '.join(self.cmd)}]") - return + return "" try: # Set up process creation flags and environment # Create logs directory - log_dir = Path("launcher_record/companion_logs/companion") + log_dir = Path("saved_experiments/companion_logs/companion") log_dir.mkdir(parents=True, exist_ok=True) # Open log file log_file = self.logger_file - if os.name == 'nt': # Windows + if os.name == "nt": # Windows # DETACHED_PROCESS flag - raise NotImplementedError("Windows support is not implemented yet.") - else: # Unix-like systems + raise NotImplementedError("Windows support is not implemented yet. Please open a feature request.") + + else: + # Unix-like systems # Use nohup and redirect output - logger.warning("\nlaunching: " + " ".join(self.cmd)) - logger.warning(f"\nlogging to {log_file}\n") # Open log file if log_file.exists(): os.remove(log_file) + if not self.use_pty: - f = open(log_file, 'a') + print_dict( + { + "Action": "Launching command", + "Command": " ".join(self.cmd), + "LogFile": str(log_file), + }, + header="Smart Daemon Launch", + ) + f = open(log_file, "a") + + # for key in ['COLORTERM', 'LS_COLORS', 'CLICOLOR', 'CLICOLOR_FORCE', 'FORCE_COLOR']: + # env_dict.pop(key, None) + env_dict.update( + { + "NO_COLOR": "1", + # 'TERM': 'dumb', + # 'PYTHONUNBUFFERED': '1', + "LOGURU_COLORIZE": "NO", + # 'LOGURU_AUTOINIT': '1', + } + ) + proc = subprocess.Popen( self.cmd, stdout=f, stderr=subprocess.STDOUT, stdin=subprocess.DEVNULL, cwd=self.dir, - env={'ScriptHash': self.script_hash, **env_dict}, - start_new_session=True # Start new session + env={"ScriptHash": self.script_hash, **env_dict}, + start_new_session=True, # Start new session ) f.close() # Close append handle pgid = os.getpgid(proc.pid) else: - import base64 - - def string_to_base64(s): - # 首先将字符串编码为字节 - s_bytes = s.encode('utf-8') - # 将字节转换为 base64 - base64_bytes = base64.b64encode(s_bytes) - # 将 base64 字节转换回字符串 - base64_string = base64_bytes.decode('utf-8') - return base64_string - - f = open(log_file, 'a') + # if pty is used, we cannot use nohup, build and pass command differently + f = open(log_file, "a") converted_cmd = [ - sys.executable, - "-m", - "astune.utils.pty", - "--human-cmd", f"'{string_to_base64(self.cmd[0])}'", - "--dir", self.dir, - "--env", json.dumps(env_dict), - ] - print('running pty command:', ' '.join(converted_cmd)) + sys.executable, + "-m", + "ajet.utils.pty", + "--human-cmd", + f"'{string_to_base64(self.cmd[0])}'", + "--dir", + self.dir, + "--env", + json.dumps(env_dict), + ] + print_dict( + { + "Action": "Launching command via PTY", + "Command": " ".join(self.cmd), + "LogFile": str(log_file), + "Converted": " ".join(converted_cmd), + }, + header="Smart Daemon Launch - PTY", + ) proc = subprocess.Popen( converted_cmd, stdout=f, stderr=subprocess.STDOUT, stdin=subprocess.DEVNULL, cwd="./", - env={'ScriptHash': self.script_hash, **env_dict}, - start_new_session=True # Start new session + env={"ScriptHash": self.script_hash, **env_dict}, + start_new_session=True, # Start new session ) f.close() # Close append handle pgid = os.getpgid(proc.pid) # write pgid to {log_file}.pgid - with open(self.pgid_file, 'w') as f_pgid: + with open(self.pgid_file, "w") as f_pgid: f_pgid.write(str(pgid)) # Monitor log file for success string or timeout start_time = time.time() f_read = "" previous_r_print = False - with open(log_file, 'r') as f: - while time.time() - start_time < launch_wait_time: + with open(log_file, "r") as f: + while (time.time() - start_time) < launch_wait_time: f_read_ = f.read() - inc_read = f_read_[len(f_read):] + inc_read = f_read_[len(f_read) :] f_read = f_read_ # Update f_read to the latest content - if success_std_string: + + if success_std_string_arr: # Move to end of file and read new content - if success_std_string in f_read: - print(f"Found success string '{success_std_string}' in output") + hit_success_string = False + for success_std_string in success_std_string_arr: + if success_std_string in f_read: + hit_success_string = True + hit_success_string_content = success_std_string + print(f"Found success string '{hit_success_string_content}' in output") + break + # if we have reached finish line, then break + if hit_success_string: + f_read_trim = inc_read # .replace("\n", " ") + print(f"Waiting for process launch [PGID {pgid}, PID {proc.pid}] ({f_read_trim})") break + + # let hold for one second time.sleep(1) remaining = int(launch_wait_time - (time.time() - start_time)) - f_read_trim = inc_read.replace('\n', ' ') + + # trim output for printing + f_read_trim = inc_read # .replace("\n", " ") if f_read_trim: - if previous_r_print: print('') - print(f"Waiting for process launch... {remaining}s remaining ({f_read_trim})") + if previous_r_print: + print("") + print(f"Waiting for process launch [PGID {pgid}, PID {proc.pid}]... {remaining}s remaining ({f_read_trim})") previous_r_print = False else: - print(f"\rWaiting for process launch... {remaining}s remaining", end='', flush=True) + print( + f"\rWaiting for process launch [PGID {pgid}, PID {proc.pid}]... {remaining}s remaining", + end="", + flush=True, + ) previous_r_print = True if remaining % 10 == 0: @@ -259,14 +374,35 @@ def string_to_base64(s): if success_std_string: raise TimeoutError(f"Process did not output success string '{success_std_string}' within {launch_wait_time} seconds") - logger.success(f"Successfully launched {self.cmd} with PID {proc.pid}") + logger.success(f"Successfully launched {self.cmd} with PID {proc.pid} (Discovered {hit_success_string_content})") + print_dict( + { + "Result": "Successfully launched", + "Command": " ".join(self.cmd), + "PID": proc.pid, + } + ) + return hit_success_string_content except Exception as e: logging.error(f"Error launching script: {e}") - raise + raise e + class LaunchCommandWhenAbsent(LaunchWhenAbsent): - def __init__(self, full_argument_list: List[str], dir = None, tag = "", use_pty=False): + """ + Launch a command as a detached process if not already running, using a hash of the command for uniqueness. + """ + + def __init__(self, full_argument_list: List[str], dir=None, tag="", use_pty=False): + """ + Initialize with the full argument list for the command. + Args: + full_argument_list (List[str]): Command and arguments to launch. + dir (str, optional): Working directory. + tag (str, optional): Tag for log file naming. + use_pty (bool, optional): Whether to use PTY for the process. + """ if not dir: self.dir = os.getcwd() else: @@ -277,18 +413,18 @@ def __init__(self, full_argument_list: List[str], dir = None, tag = "", use_pty= full_argument_list_compute_hash = full_argument_list.copy() if full_argument_list_compute_hash[0] == sys.executable: - full_argument_list_compute_hash[0] = 'python' + full_argument_list_compute_hash[0] = "python" hash_items = full_argument_list_compute_hash + [str(self.dir)] - self.script_hash = hashlib.md5(''.join(hash_items).encode()).hexdigest()[:8] + self.script_hash = hashlib.md5("".join(hash_items).encode()).hexdigest()[:8] if self.use_pty: assert len(full_argument_list) == 1 self.cmd = full_argument_list else: - self.cmd = ['nohup'] + full_argument_list + self.cmd = ["nohup"] + full_argument_list # raise ValueError(self.script_hash) - log_dir = Path("launcher_record/companion_logs/companion") + log_dir = Path("saved_experiments/companion_logs/companion") log_dir.mkdir(parents=True, exist_ok=True) hostname = os.uname().nodename if tag: diff --git a/ajet/utils/testing_utils.py b/ajet/utils/testing_utils.py new file mode 100644 index 00000000..faf1e672 --- /dev/null +++ b/ajet/utils/testing_utils.py @@ -0,0 +1,268 @@ +# flake8: noqa E261, E131, E241 + +import os +import subprocess +import sys +import time +from typing import List + +import requests +from beast_logger import print_dict +from loguru import logger + +from ajet.utils.dynamic_import import dynamic_import +from ajet.utils.sington import singleton + + +class TestSuccessException(Exception): + """ + All test is done, end the program early with exception. + """ + + pass + + +class TestFailException(Exception): + """ + Test has failed, end the program early with exception. + """ + + pass + + +class BaseProbe(object): + """ + The basic test probe class, capture keyword if matched `self.probe_list`, and do test. + """ + + def __init__(self): + self.probe_list: List[str] = [] + + def __call__(self, key: str, log_dict: dict): + raise NotImplementedError + + def mock(self, key: str): + raise NotImplementedError + + +def get_test_lambda(test_name) -> BaseProbe: + test_cls = dynamic_import(test_name)() + return test_cls + + +def _test_if_test_mode(key, value, config): + from ajet.backbone.warm_up import init_parallel_rollout_logger + + if not config.ajet.execute_test: + return + if config.ajet.execute_test == "do_not_test": + return + init_parallel_rollout_logger(config.ajet.experiment_name) + test_lambda = get_test_lambda(config.ajet.execute_testing_lambda) + if key not in test_lambda.probe_list: + return + return test_lambda(key, value) + + +def _mock_if_test_mode(key, value, config): + if not config.ajet.execute_test: + return value + if config.ajet.execute_test == "do_not_test": + return value + test_lambda = get_test_lambda(config.ajet.execute_testing_lambda) + if key not in test_lambda.probe_list: + return value + return test_lambda.mock(key) + + +def send_test_result( + git_hash: str, + target: str, + status: str, + status_detail: str = "", + req_txt: str | None = None, + append_log: str | None = None, + data_dashboard_url: str | None = None, + timeout: float = 10.0, +) -> dict: + """ + Post a single experiment result to the /report_test_result endpoint. + Raises requests.HTTPError on non-2xx responses. + """ + + access_token = os.environ.get("BENCHMARK_ACCESS_TOKEN", None) + if not access_token: + logger.error("Cannot report to benchmark site, missing administrator token (`BENCHMARK_ACCESS_TOKEN` env variable).") + return {} + + payload = { + "access_token": access_token, + "git_hash": git_hash, + "target": target, + "status": status, + "status_detail": status_detail, + "req_txt": req_txt or "", + "append_log": append_log or "", + "data_dashboard_url": data_dashboard_url or "", + } + resp = requests.post( + r"https://benchmark-report.agent-matrix.com/report_test_result", + json=payload, + timeout=timeout, + ) + resp.raise_for_status() + return resp.json() + + +def populate_test_env_metadata(workspace_dir: str) -> tuple[str, str]: + """Capture git hash and pip freeze output, store them in env, return both.""" + git_hash = "unknown" + try: + git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=workspace_dir).decode().strip() + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + req_txt = "" + try: + req_txt = subprocess.check_output([sys.executable, "-m", "pip", "freeze"], cwd=workspace_dir).decode().strip() + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + return git_hash, req_txt + + +def update_benchmark_status(status, status_detail, append_log="", data_dashboard_url=""): + if "AJET_GIT_HASH" not in os.environ: + return + + git_hash = os.environ["AJET_GIT_HASH"] + req_txt = os.environ["AJET_REQ_TXT"] + target_name = os.environ["AJET_BENCHMARK_NAME"] + + if not append_log: + append_log = status_detail + + send_test_result( + git_hash=git_hash, + target=target_name, + status=status, + status_detail=status_detail, # + req_txt=req_txt, # get pip freeze + append_log=append_log, + data_dashboard_url=data_dashboard_url, + timeout=10.0, + ) + + +class BenchmarkProbe(BaseProbe): + """ + A benchmark probe to test reward during training. + Major module input: + - self.reward_expectation: dict, key is step, value is [low, high] expected reward range + - self.reward_expectation_avg_window: int, number of steps to average reward over + - self.expected_train_time: int, expected training time in seconds + """ + + def __init__(self): + # fmt: off + self.expected_train_time = 3600 * 24 # 24 hours + self.begin_time = time.time() + self.reward_array = [] + self.reward_expectation_avg_window = 5 + self.reward_expectation = { + # step : expected local average reward range + # step : [low, high ] + 5 : [0.10, 99999.0], + 10 : [0.45, 99999.0], + 20 : [0.68, 99999.0], + 30 : [0.85, 99999.0], + } + # fmt: on + self.probe_list = ["reward_probe"] + + self.reward_key = "reward_for_test_robot" + self.probe_key = "reward_probe" + + def __call__(self, key, log_dict): + reward = self.reward_key + if key == self.probe_key: + step = log_dict["step"] + + if time.time() - self.begin_time > self.expected_train_time: + msg = f"Training time exceeded expected limit of {self.expected_train_time} seconds." + update_benchmark_status( + status="fail", + status_detail=msg, + append_log=msg, + ) + raise TestFailException(msg) + + # if new data, add + logger.bind(benchmark=True).info(f"log_dict: {str(log_dict)}") + logger.bind(benchmark=True).info(f"reward_key: {str(reward)}") + logger.bind(benchmark=True).info(f"self.reward_array before: {str(self.reward_array)}") + if reward in log_dict: + reward = log_dict[reward] + self.reward_array += [reward] + + update_benchmark_status( + status="running", + status_detail=f"Current step: {step}", + append_log=f"Step {step}: reward logged, {str(self.reward_array)}.", + data_dashboard_url=log_dict["data_dashboard_url"], + ) + + # begin test + if step in self.reward_expectation: + # compute local average reward + if len(self.reward_array) == 0: + err = f"No reward logged at step {step}" + update_benchmark_status( + status="fail", + status_detail=err, + ) + raise TestFailException(err) + # compute local average reward over last self.reward_expectation_avg_window steps + local_avg_reward = sum(self.reward_array[-self.reward_expectation_avg_window :]) / min(self.reward_expectation_avg_window, len(self.reward_array)) + # get expected range + low, high = self.reward_expectation[step] + # log + msg = f"[TestProbe] Step {step}: local average reward over last self.reward_expectation_avg_window steps: {local_avg_reward:.4f}, expected range: [{low}, {high}]" + logger.bind(benchmark=True).info(msg) + update_benchmark_status( + status="running", + status_detail=msg, + ) + # check + if not (low <= local_avg_reward <= high): + # test failed + print_dict( + { + "step": step, + "local_avg_reward": local_avg_reward, + "expected_low": low, + "expected_high": high, + }, + mod="benchmark", + ) + err = f"[TestProbe] Reward test failed at step {step}: local average reward {local_avg_reward:.4f} not in expected range [{low}, {high}]" + logger.bind(benchmark=True).error(err) + update_benchmark_status( + status="fail", + status_detail=err, + ) + raise TestFailException(err) + else: + # test passed + msg = f"[TestProbe] Reward test passed at step {step}." + logger.bind(benchmark=True).info(msg) + update_benchmark_status(status="running", status_detail=msg) + # congrats, all tests passed, let's crash and escape this test early. + if step == max(self.reward_expectation.keys()): + msg = "[TestProbe] All reward tests passed. Exiting training early." + logger.bind(benchmark=True).info(msg) + update_benchmark_status( + status="successful", + status_detail=msg, + ) + raise TestSuccessException(msg) diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py new file mode 100644 index 00000000..6981408c --- /dev/null +++ b/ajet/utils/thread_executors.py @@ -0,0 +1,20 @@ +from ajet.utils.sington import singleton +import concurrent.futures + + +@singleton +class SharedInterchangeThreadExecutor: + def __init__(self, max_workers=64): + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + + def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: + return self.executor + + +@singleton +class SharedInferenceTrackerThreadExecutor: + def __init__(self, max_workers=64): + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) + + def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor: + return self.executor diff --git a/ajet/utils/tokenizer.py b/ajet/utils/tokenizer.py new file mode 100644 index 00000000..94ab8007 --- /dev/null +++ b/ajet/utils/tokenizer.py @@ -0,0 +1,43 @@ +import copy +import json +from typing import Dict, List + + +def cleanup_messages(messages: List[Dict]) -> List[Dict]: + "A temperary fix for tool_calls being str instead of dict" + messages_copied = copy.deepcopy(messages) + for m in messages_copied: + if "tool_calls" not in m: + continue + for t in m["tool_calls"]: + if "function" not in t or "arguments" not in t["function"]: + continue + if isinstance(t["function"]["arguments"], str): + try: + t["function"]["arguments"] = json.loads(t["function"]["arguments"]) + except Exception: + pass + return messages_copied + + +def ajet_apply_chat_template( + tokenizer, + conversation, + tools, + add_generation_prompt: bool = False, + tokenize: bool = True, +): + conversation = cleanup_messages(conversation) + if tools: + return tokenizer.apply_chat_template( + conversation, + tools, + add_generation_prompt=add_generation_prompt, + tokenize=tokenize, + ) + else: + return tokenizer.apply_chat_template( + conversation, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + ) diff --git a/ajet/utils/vsdb.py b/ajet/utils/vsdb.py new file mode 100644 index 00000000..21d0369b --- /dev/null +++ b/ajet/utils/vsdb.py @@ -0,0 +1,63 @@ +import os +import pickle + +""" +Ray Distributed Debugger VSCode Extension (Recommended) + +Starting with Ray 2.39, Anyscale has introduced the +`Ray Distributed Debugger `_ VSCode extension. + +1. Install the Ray Distributed Debugger extension in VSCode. + +2. In AgentJet project: + + 2-1. In the place your want to set a conditional breakpoint, write + `from ajet import bp; bp("TAG_1")` + + 2-2. When launching the training process, add `--debug` argument + `ajet --conf your_config.yaml --debug="TAG_1"` + + 2-3. Open Tab "Ray Distributed Debugger" in VSCode, and just wait until the breakpoint is hit. + +""" + + +def vscode_conditional_breakpoint(tag=None, once=True): + env_tag = f"HIT_BREAKPOINT_REC_{tag}" + if not os.getenv("RAY_DEBUG_POST_MORTEM"): + return + if tag is None: + if once: + if os.getenv(env_tag, "") != "1": + os.environ[env_tag] = "1" + breakpoint() + return + else: + breakpoint() + return + else: + debug_tags = os.getenv("DEBUG_TAGS", "").split("|") + if tag in debug_tags: + if once: + if os.getenv(env_tag, "") != "1": + os.environ[env_tag] = "1" + breakpoint() + return + else: + breakpoint() + return + + +def objdump(obj, file="objdump.tmp"): + with open(file, "wb+") as f: + pickle.dump(obj, f) + return + + +def objload(file="objdump.tmp"): + import os + + if not os.path.exists(file): + return + with open(file, "rb") as f: + return pickle.load(f) diff --git a/ajet/workflow.py b/ajet/workflow.py new file mode 100644 index 00000000..58c8757d --- /dev/null +++ b/ajet/workflow.py @@ -0,0 +1,126 @@ +from typing import List + +from pydantic import BaseModel, Field + +from ajet import AjetTuner +from ajet.schema.task import WorkflowOutput, WorkflowTask + + +class Workflow(BaseModel): + model_config = {"extra": "allow"} + name: str = Field(default="default_workflow", description="Name of the workflow.") + trainable_targets: List[str] | None = Field( + default=None, + description="List of agents to be fine-tuned. When None, all agents are trainable.", + ) + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + """Run the workflow on a given task.""" + raise NotImplementedError + + +""" +How to define a trainable workflow 🚀: + +1. Single agent scenario 🤖: + + Simply set `model` argument to `tuner.as_agentscope_model()` when initializing your agent. + This is a helpful example when you: + - 🌟 Know exactly which agents should be trained, or the number of agents are small; + - ✨ Already finished basic debugging of your workflow using a fixed model such as qwen-max; + - 🎇 Do not requires changing which agents to be trained on the fly. + + ----- EXAMPLE ----- + + - Suppose you have a react agent that looks like this: + + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=DashScopeChatModel(model='qwen-max'), + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + ) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg, structured_model=FinalResult) + final_answer = extract_final_answer(result) + + + - Then all you have to do is to wrap it in a workflow class: + + [+] class ExampleMathLearn(Workflow): + [+] name: str = "math_agent_workflow" + [+] async def execute(self, task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + [ ] from agentscope.agent import ReActAgent + [ ] from agentscope.formatter import DashScopeChatFormatter + [ ] from agentscope.memory import InMemoryMemory + [ ] from agentscope.tool import Toolkit, execute_python_code + [ ] self.toolkit = Toolkit() + [ ] self.toolkit.register_tool_function(execute_python_code) + [ ] self.agent = ReActAgent( + [ ] name="math_react_agent", + [ ] sys_prompt=system_prompt, + [+-] model=tuner.as_agentscope_model(), + [ ] formatter=DashScopeChatFormatter(), + [ ] toolkit=self.toolkit, + [ ] memory=InMemoryMemory(), + [ ] ) + [+] query = task.task.main_query + [ ] msg = Msg("user", query, role="user") + [ ] result = await self.agent.reply(msg, structured_model=FinalResult) + [ ] final_answer = extract_final_answer(result) + [+] return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) + + +2. Multi-agent scenario 🤝: + + Use `register_model` method of `ModelTuner` to register different agent targets. + This is extremely helpful when you want to + - 🌟 Achieve fine-grained control over which agents to be fine-tuned; + - ✨ Define what model agents should use when they are NOT being tuned; + - ⚡ Change which trainable agent targets on the fly without modifying the workflow code. + + ----- EXAMPLE ----- + + [ ] roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] + [ ] players = [] + [ ] for i, role in enumerate(roles): + [ ] debug_model_for_good_guys = OpenAIChatModel(model_name="qwen-max", stream=False) + [ ] debug_model_for_bad_guys = OpenAIChatModel(model_name="qwen-plus", stream=False) + [ ] chosen_model = debug_model_for_good_guys if role != "werewolf" else debug_model_for_bad_guys # 🌟 + [ ] players += [ReActAgent( + [ ] name=f"Player{i + 1}", + [ ] sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), + [-] model=chosen_model, + [+] model=tuner.as_agentscope_model(f"Player{i + 1}", role, debug_model=chosen_model), + [ ] formatter=OpenAIMultiAgentFormatter(), + [ ] )] + + +[ ] roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] +[ ] players = [] +[ ] for i, agent_role in enumerate(roles): +[ ] if agent_role != "werewolf": +[ ] chosen_model_for_current_agent = OpenAIChatModel(model_name="qwen-max", stream=False) +[ ] else: +[ ] chosen_model_for_current_agent = OpenAIChatModel(model_name="qwen-plus", stream=False) +[ ] players += [ReActAgent( +[ ] name=f"Player{i + 1}", +[ ] sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), +[ ] model=agentscope_model, +[ ] model=tuner.as_agentscope_model( +[ ] agent_name=f"Player{i + 1}", +[ ] target_tag=agent_role, # 🌟 tag agents with their role +[ ] debug_model=chosen_model_for_current_agent # 🌟 assign a debug model, ONLY used when we are NOT training this agent +[ ] ) +[ ] formatter=OpenAIMultiAgentFormatter(), +[ ] )] + +""" diff --git a/astune/agent_flow.py b/astune/agent_flow.py deleted file mode 100644 index 8f48957e..00000000 --- a/astune/agent_flow.py +++ /dev/null @@ -1,175 +0,0 @@ -import time -import os - -from loguru import logger -from astune.env_service_client.env_client import EnvClient -from astune.utils.utils import convert_tool_to_user_message -from astune.schema.trajectory import Reward -from astune.context_manager.cmt_linear import CMTLinear, ExtendedMessage -from astune.context_manager.cmt_linear_think import LinearThinkCMT -from astune.context_manager.cmt_context_clip import SelfContextClipCMT -from astune.context_manager.cmt_sliding_window import SlidingWindowCMT -from typing import Any, Dict, List, Union, Callable -from beast_logger import print_listofdict -import threading - -log_generate_lock = threading.Lock() - -class BaseAgentFlow(object): - - def __init__(self, - llm_chat_fn: Callable, - tokenizer: Any, - config, - **kwargs): - self.tokenizer = tokenizer - self.instruction_template_ids = self.tokenizer.encode("<|im_start|>user\n") - self.response_template_ids = self.tokenizer.encode("<|im_start|>assistant\n") - self.cmt: Union[CMTLinear, LinearThinkCMT, Any, None] = None - self.alien_llm_chat_fn: Union[Callable, None] = None - self.llm_chat_fn: Callable = llm_chat_fn - self.config = config - # self.console_debug_mode: bool = False - self.max_steps: int = self.config.astune.rollout.multi_turn.max_steps - self.max_model_len: int = self.config.astune.rollout.max_model_len - self.max_env_len: int = self.config.astune.rollout.max_env_len - -class AgentFlow(BaseAgentFlow): - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.use_step_reward_from_env: bool = self.config.astune.rollout.get("use_step_reward_from_env", False) - self.step_reward = [] - - - def execute(self, init_messages: List[dict], env: EnvClient, task_core_arg) -> CMTLinear: - obs_window = task_core_arg.obs_window - task_thread_index = task_core_arg.task_thread_index - - # 1. 🚀 Initialize messages - if self.config.astune.context_manager.context_manager_type == "linear": - self.cmt = CMTLinear(self.config, self.tokenizer) - elif self.config.astune.context_manager.context_manager_type == "linear_think": - self.cmt = LinearThinkCMT(self.config, self.tokenizer) - elif self.config.astune.context_manager.context_manager_type == "context_selfclip": - self.cmt = SelfContextClipCMT(self.config, self.tokenizer, self.llm_chat_fn) - elif self.config.astune.context_manager.context_manager_type == "sliding_window": - self.cmt = SlidingWindowCMT(self.config, self.tokenizer, self.llm_chat_fn) - else: - raise ValueError(f"Unsupported context template: {self.config.astune.context_manager.context_manager_type}") - - assert not (self.config.astune.rollout.force_think and self.config.astune.rollout.force_no_think), "Cannot force both think and no_think" - add_nothink = self.config.astune.rollout.force_no_think - - self.cmt.save_init_input(init_messages, add_nothink) - - request_id: str = "" - for act_step in range(self.max_steps): - # 2. 🔄 Update thread progress - obs_window['step'][task_thread_index] = act_step - if (obs_window['stop'] is not None) and obs_window['stop'][task_thread_index]: # Check if the thread should obs_window['stop'] (because other threads have completed, making this thread useless) - self.cmt.discarded = True - break - - # 3. ⏮️ get previous steps - try: - step_input_message_arr = self.cmt.prepare_next_llm_context() - except Exception as e: - print_listofdict(self.cmt.to_role_content(self.cmt.full_context), mod='exception', header="Before Crash") - raise e - - # 4. ⚠️ check token overflow - is_safe, info = self.cmt.check_context_token_num_safe(step_input_message_arr) - if not is_safe: - logger.warning(f"[{info}] detected at step {act_step}. Current token count exceeds the limit.") - self.cmt.is_terminated = True - break - - # 5. 🤖 call llm - llm_output = self.llm_chat_fn(step_input_message_arr, request_id=request_id) - if (obs_window['stop'] is not None) and obs_window['stop'][task_thread_index]: # Check if the thread should obs_window['stop'] (because other threads have completed, making this thread useless) - self.cmt.discarded = True - break - - # 6. 💾 save llm output - self.cmt.save_llm_output(llm_output, input_msg_ref=step_input_message_arr) - obs_window['token'][task_thread_index] += self.cmt.generated_token_cnt - - # 7. 🌍 world interaction - try: - env_output = env.step( - instance_id=task_core_arg.task_env_uuid, - action={"content": self.cmt.prepare_world_interaction(), "role": "assistant"}, - params={"step_skip_action": self.config.astune.rollout.step_skip_action} - ) - if env_output["state"]["role"] == "tool": - env_output["state"] = convert_tool_to_user_message(env_output["state"], self.tokenizer, format="qwen") - # if self.console_debug_mode: - # if isinstance(env_output["state"], dict): - # print_listofdict( - # step_input_message_arr + - # [{'role': 'llm_latest', 'content': llm_output['content']}] + - # [{'role': 'env', 'content': env_output["state"]['content']}] - # , mod='c') - except Exception as e: - logger.bind(exception=True).exception(f"call env.step error with {e}") - self.cmt.is_terminated = True - state = {"content": str(e), "role": "user"} - env_output = { - "reward": 0, - "is_terminated": True, - "state": state, - } - - # 8. 📥 save environment output - state = env_output["state"] - state.pop('tool_calls', None) - self.cmt.save_env_output(state, input_msg_ref=step_input_message_arr, add_nothink=add_nothink) - self.cmt.round_cnt += 1 - if self.use_step_reward_from_env: - self.step_reward += [env_output["reward"]] - - # 9. 🔚 determine if the episode is terminated - self.cmt.is_terminated = env_output["is_terminated"] - if self.cmt.is_terminated: - break - - self.cmt.ensure_terminate_rollout_stage() - obs_window['step'][task_thread_index] = -1 - raw_reward = 0 - raw_reward = env.evaluate(task_core_arg.task_env_uuid, params={"sparse": False}) - if raw_reward >= 1: - success_rate = 1.0 - else: - success_rate = 0.0 - if not self.use_step_reward_from_env: - if self.config.astune.rollout.add_special_success_reward: - if success_rate == 1: - raw_reward = 1.0 + raw_reward * 0.5 - else: - raw_reward = 0.0 + raw_reward * 0.5 - if self.config.astune.rollout.binary_reward: - raw_reward = success_rate - self.cmt.process_reward( - reward_structure = Reward( - raw_reward=raw_reward, - raw_step_reward=None, - success_rate=success_rate, - madness=0, - description="Success=1, Failure=0" - ) - ) - else: - self.cmt.process_reward( - reward_structure = Reward( - raw_reward=raw_reward, - raw_step_reward=self.step_reward, - success_rate=success_rate, - madness=0, - description="Step Reward from Environment" - ) - ) - - self.cmt.remove_last_context() - - return self.cmt diff --git a/astune/agentscope_flow.py b/astune/agentscope_flow.py deleted file mode 100644 index 375a56e6..00000000 --- a/astune/agentscope_flow.py +++ /dev/null @@ -1,102 +0,0 @@ -import threading -import importlib -import torch -import copy -import asyncio -from astune.env_service_client.env_client import EnvClient -from astune.agent_flow import BaseAgentFlow -from astune.schema.trajectory import Reward, Trajectory -from astune.context_manager.cmt_linear import CMTLinear, ExtendedMessage -from astune.protocol.agentscope_protocol import AgentScopeLearnProtocol -from astune.context_manager.cmt_linear import replace_token_ids, CMTLinear -from astune.schema.trajectory import Sample, Reward -from typing import Any, Dict, List, Union, Tuple -from astune.context_manager.cmt_agentscope import BeyondAgentProxy -from astune.schema.task import Task, TaskLaunchCoreArgument - -log_generate_lock = threading.Lock() - -class AgentScopeWorkflow(BaseAgentFlow): - - def execute(self, init_messages: List[dict], env: EnvClient, task_core_arg: TaskLaunchCoreArgument) -> CMTLinear: - obs_window = task_core_arg.obs_window - task_thread_index = task_core_arg.task_thread_index - task_batch_index = task_core_arg.task_batch_index - task_tag = task_core_arg.task_tag - task_id = task_core_arg.task_id - - # fetch learn protocol - protocol = self.config.astune.rollout.agentscope_learn_protocol - module_, class_ = protocol.split('->') - protocol_cls: AgentScopeLearnProtocol = getattr(importlib.import_module(module_), class_) - agentscope_protocol = protocol_cls(trainer='astune-trinity', agentflow_name='appworld') # type: ignore - - def env_step_fn(action: dict) -> Tuple[str, float, bool, dict]: - obs_window['step'][task_thread_index] += 1 - env_output = env.step( - instance_id=task_core_arg.task_env_uuid, - action=action, - ) - obs = "" - assert isinstance(env_output, dict) - if ('content' not in env_output["state"]) and ('error' in env_output["state"]): - obs = f"[Error from environment: {env_output['error']}]" - elif (env_output["state"]['content']==""): - obs = 'Warning: the environment does not provide any feedback, please provide valid inpu and try again.' - else: - obs = env_output["state"]['content'] - reward = 0 - info = {} - terminate = env_output["is_terminated"] - return obs, reward, terminate, info - def should_interrupt_fn() -> bool: - if (obs_window['stop'] is not None) and obs_window['stop'][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless) - return True - return False - def generated_token_callback_fn(token_array): - with log_generate_lock: - obs_window['token'][task_thread_index] += len(token_array) - - beyondagent_proxy = BeyondAgentProxy( - llm_chat_fn=self.llm_chat_fn, - tokenizer=self.tokenizer, - config=self.config, - model_name='beyondagent-proxy', - api_key='dummy-api-key', - task_batch_index=task_batch_index, - task_tag=task_tag, - task_id=task_id, - env_step_fn=env_step_fn, - should_interrupt_fn=should_interrupt_fn, - generated_token_callback_fn=generated_token_callback_fn, - ) - - beyondagent_proxy.update_agentscope_input_dictionary(task_core_arg=task_core_arg) - beyondagent_proxy = asyncio.run(agentscope_protocol.agentscope_execute(init_messages, beyondagent_proxy, self.config)) - beyondagent_proxy.update_judge_input_dictionary(task_core_arg=task_core_arg) - beyondagent_proxy.update_judge_input_dictionary(env=env) - beyondagent_proxy.update_judge_input_dictionary(grouped_steps=beyondagent_proxy.grouped_steps) - - raw_reward, is_success = beyondagent_proxy.get_judge().compute_reward( - beyondagent_proxy.get_judge_input_dictionary() - ) - - # evaluate - reward = Reward( - raw_reward=raw_reward, - raw_step_reward=None, - success_rate=1.0 if is_success else 0.0, - madness=0, - description="" - ) - beyondagent_proxy.process_reward(reward) - - # generate token before merging - beyondagent_proxy.remove_last_context() - beyondagent_proxy.task_id = task_id - beyondagent_proxy.task_tag = task_tag - beyondagent_proxy.group_merge() - beyondagent_proxy.process_reward(reward) - return beyondagent_proxy - - diff --git a/astune/backbone_trinity/__init__.py b/astune/backbone_trinity/__init__.py deleted file mode 100644 index 5ab258fe..00000000 --- a/astune/backbone_trinity/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from astune.backbone_trinity.register_flow import * \ No newline at end of file diff --git a/astune/backbone_trinity/register_flow.py b/astune/backbone_trinity/register_flow.py deleted file mode 100644 index 9900f8e7..00000000 --- a/astune/backbone_trinity/register_flow.py +++ /dev/null @@ -1,230 +0,0 @@ -import os -import uuid -import hydra -import openai -import numpy as np -import asyncio, uuid, copy -import threading - -from typing import Dict, List, Optional, Union -from trinity.common.experience import Experience -from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow -from trinity.common.workflows.agentscope.react.templates import TEMPLATE_MAP -from transformers import AutoTokenizer -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Literal, Callable, Union -from loguru import logger -from omegaconf import DictConfig -from tensordict import TensorDict -from torch.nn.utils.rnn import pad_sequence -from tqdm import tqdm -from verl import DataProto -from verl.utils.torch_functional import pad_sequence_to_length -from beast_logger import register_logger, print_dict, print_listofdict -from astune.schema.task import Task -from astune.utils.utils import run_async_coro__no_matter_what_the_fuck -from astune.parallel_env import DynamicRollout -from astune.schema.logprob import TokenAndProb -from astune.schema.task import Task -from astune.schema.trajectory import Sample -from omegaconf import OmegaConf - -class TrinityCompatWorkflow(DynamicRollout): - - def __init__(self, task, llm_handle, tokenizer, config, llm_mode="trinity", **kwargs): - - self.task = task - self.trinity_llm_model_client = llm_handle - self.tokenizer = tokenizer - self.config = config - self.llm_mode = "trinity" - - super().__init__( - config=self.config, - async_rollout_manager=None, - max_parallel=1, - max_llm_retries = 1, - tokenizer=tokenizer, - llm_mode=llm_mode, - **kwargs - ) - - def convert_task(self, task): - main_query = task.raw_task.get('main_query', "[not defined]") - task_id = task.raw_task.get('task_selector', str(uuid.uuid4().hex)) - env_type = task.raw_task.get('env_type', "[not defined]") - metadata = task.raw_task.get('metadata', {}) - init_messages = task.raw_task.get('init_messages', []) - - return Task( - main_query=main_query, - task_id=task_id, - env_type=env_type, - metadata=metadata, - init_messages=init_messages, - ) - - def thread_worker(self): - obs_window = { - 'stop': [False], - 'step': [0], - 'token': [0], - } - astune_task = self.convert_task(self.task) - return self.rollout_env_worker( - task=astune_task, - task_batch_index=0, - task_tag=f"T{astune_task.task_id}#R?", - mode="sample", - task_thread_index=0, - obs_window=obs_window - ) - - def run_in_new_thread(self): - # begin self.thread_worker in a new thread - # then wait for it to finish, and get the result - - result_holder = {} - exc_holder = {} - - def _target(): - try: - result_holder["result"] = self.thread_worker() - except Exception as e: - exc_holder["exc"] = e - - t = threading.Thread(target=_target, daemon=True) - t.start() - t.join() - - if "exc" in exc_holder: - raise exc_holder["exc"] - - return result_holder.get("result", None) - - -def read_astune_config(yaml_fp): - from hydra import initialize, compose - from omegaconf import DictConfig - - def load_hydra_config(config_path: str, config_name: str) -> DictConfig: - with initialize(config_path=config_path, version_base=None): - cfg = compose(config_name=config_name, overrides=[]) - return cfg - - dir_path = os.path.dirname(yaml_fp) - file_name = os.path.basename(yaml_fp) - return load_hydra_config(config_path=dir_path, config_name=file_name) - - -@WORKFLOWS.register_module("astune_workflow") -class astunetWorkflowWrap(Workflow): - is_async: bool = True - def __init__( - self, - config, - model: ModelWrapper, - task: Task, - auxiliary_models: Optional[List[openai.OpenAI]] = None, - ): - super().__init__( - task=task, - model=model, - auxiliary_models=auxiliary_models, - ) - self.config = config - self.task = task - - # 模拟openai的异步客户端 - self.model_client = model.get_openai_async_client() - # task_type 用于获取奖励函数 - # extract the query and the answer from the task - self.query = task.raw_task.get(task.format_args.prompt_key) # type: ignore [index] - self.answer = task.raw_task.get(task.format_args.response_key) # type: ignore [index] - self.task.workflow_args = { - "env_type": "appworld", - "task_id": self.task.task_id, - "instance_id": uuid.uuid4().hex, - } - - async def run_async(self): - - yaml_path = os.environ.get('ASTUNE_CONFIG_REDIRECT', None) - if yaml_path is None: - raise ValueError("ASTUNE_CONFIG_REDIRECT is not set in environment variables") - - cmt = TrinityCompatWorkflow( - task=self.task, - llm_handle=self.model_client, - tokenizer=AutoTokenizer.from_pretrained(self.model_client.model_path), - config=read_astune_config(os.path.relpath(yaml_path, os.path.dirname(__file__))), - ).run_in_new_thread() - - sample_final = [] - try: - sample_arr = cmt.group_tokenize() - except Exception as e: - cmt.generate_log(global_step=-1) - raise e - cmt.generate_log(global_step=-1) - sample_final += sample_arr - - - exps = [] - for index, sample in enumerate(sample_final): - sample: Sample - input_ids = sample.input_ids - prompt_ids = sample.prompt_ids - response_ids = sample.response_ids - attention_mask = sample.attention_mask - prompt_attention_mask = sample.prompt_attention_mask - response_attention_mask = sample.response_attention_mask - loss_mask = sample.loss_mask - prompt_loss_mask = sample.prompt_loss_mask - response_loss_mask = sample.response_loss_mask - position_ids = sample.position_ids - prompt_position_ids = sample.prompt_position_ids - response_position_ids = sample.response_position_ids - # cmt_tokenized["step_reward"] = self.reward_structure.step_reward[index] - - logprobs = sample.response_logprobs - try: - reward = cmt.reward_structure.step_reward - if isinstance(reward, list): - reward = reward[0] - except Exception as e: - reward = cmt.reward_structure.raw_reward - if not isinstance(reward, (float, int)): # if reward is still not a float or int, set it to 0.0 - reward = cmt.reward_structure.raw_reward - - if len(response_ids) + len(prompt_ids) == len(input_ids) and len(logprobs) == len(response_ids) and len(logprobs) > 0: - exp = Experience( - # eid=uuid.uuid4().hex, - tokens = input_ids, # [seq_length] prompt + response - prompt_length = len(prompt_ids), # Length of the prompt in tokens, used for generating attention masks - logprobs = logprobs, # [resp_length] - reward = reward, # - # advantages=None, - # returns=None, - info = {}, - metrics = {}, # for wandb logging (must be string:float) - response_text = "", # optional - prompt_text = "", # optional - #### for multi-turn experiences - action_mask = response_loss_mask, # 1 是训练 - messages=sample.messages, # - # tools, - #### for dpo experiences - # chosen, - # rejected, - # chosen_messages, - # rejected_messages, - #### for multi-modal data - # multi_modal_inputs - ) - exps += [exp] - else: - from vsdb import bp - bp("BUGX") - return exps diff --git a/astune/backbone_verl/trainer.py b/astune/backbone_verl/trainer.py deleted file mode 100644 index 5b4d6563..00000000 --- a/astune/backbone_verl/trainer.py +++ /dev/null @@ -1,1683 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import json -import os -import uuid -import warnings -from collections import defaultdict -from copy import deepcopy -from dataclasses import dataclass, field -from enum import Enum -from pprint import pprint -from typing import Optional - -import numpy as np -import time -import ray -import torch -from omegaconf import OmegaConf, open_dict -from torch.utils.data import Dataset, Sampler -from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm import tqdm - -from verl import DataProto -from verl.experimental.dataset.sampler import AbstractCurriculumSampler -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.config import AlgoConfig -from verl.trainer.ppo import core_algos -from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - process_validation_metrics, -) -from verl.trainer.ppo.reward import compute_reward, compute_reward_async -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi -from verl.utils.config import omega_conf_to_dataclass -from verl.utils.debug import marked_timer -from verl.utils.metric import reduce_metrics -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance -from verl.utils.torch_functional import masked_mean -from verl.utils.tracking import ValidationGenerationsLogger - - -from astune.parallel_env import ParallelEnvManager -from astune.schema.task import Task -from astune.schema.trajectory import Trajectory -from astune.utils.message import send_train_message -from beast_logger import register_logger, print_dict -from astune.context_manager.cmt_linear import CMTLinear -import os -import json -from typing import List, Dict, Union, Type -from loguru import logger - -WorkerType = type[Worker] - - -def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor: - """ - Compute reward for a batch of data. - Args: - data: DataProto object containing the input data. - return_dict: Whether to return a dictionary or just the reward tensor. - - Returns: - Tensor of shape (bs, response_len) if return_dict is False, - or a dict with 'reward_tensor' and 'reward_extra_info'. - """ - # Within DataFlow, world.execute() will pass a float score, which will be contained in the DataProto.non_tensor_batch('reward_scores') - - # Initialize reward tensor - reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) # (bs, reslen) - reward_extra_info = defaultdict(list) - - # Batch-level processing - prompt_ids_batch = data.batch["prompts"] # (bs, prompt_len) - prompt_lengths = prompt_ids_batch.shape[-1] - - # Get attention masks for all items - attention_masks = data.batch["attention_mask"] # (bs, total_len) - response_lengths = attention_masks[:, prompt_lengths:].sum(dim=1) # (bs, ) - - # Get reward scores - reward_scores_list = [item for item in data.non_tensor_batch["reward_scores"]] - reward_scores = torch.tensor(reward_scores_list, device=reward_tensor.device, dtype=torch.float32) # (bs, ) - - # Use advanced indexing to assign rewards (把reward放到response的最后一个token位置) - reward_tensor[torch.arange(len(data)), response_lengths - 1] = reward_scores - - if return_dict: - return { - "reward_tensor": reward_tensor, - "reward_extra_info": reward_extra_info, - } - else: - return reward_tensor - -def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto): - """ - Union the gen_batch_output with the batch based on task_id. - """ - map_task_id_to_index = {t.task_id:i for i, t in enumerate(tasks)} - gen_task_task_ids = gen_batch_output.non_tensor_batch['task_ids'] - indices = [map_task_id_to_index[tid] for tid in gen_task_task_ids] - batch_extend = batch.select_idxs(indices) - batch_final = batch_extend.union(gen_batch_output) - return batch_final - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - - Actor = 0 - Rollout = 1 - ActorRollout = 2 - Critic = 3 - RefPolicy = 4 - RewardModel = 5 - ActorRolloutRef = 6 - - -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - """Create Ray resource pools for distributed training. - - Initializes resource pools based on the resource pool specification, - with each pool managing GPU resources across multiple nodes. - For FSDP backend, uses max_colocate_count=1 to merge WorkerGroups. - For Megatron backend, uses max_colocate_count>1 for different models. - """ - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 - # that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = { - node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) - for node, node_info in node_available_resources.items() - } - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - # check each resource pool can be satisfied, O(#resource_pools * #nodes) - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) - for node, available_gpus in node_available_gpus.items(): - if available_gpus >= num_gpus: - node_available_gpus[node] -= num_gpus - num_nodes -= 1 - if num_nodes == 0: - break - if num_nodes > 0: - raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" - + "cannot be satisfied in this ray cluster" - ) - - -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): - """Apply KL penalty to the token-level rewards. - - This function computes the KL divergence between the reference policy and current policy, - then applies a penalty to the token-level rewards based on this divergence. - - Args: - data (DataProto): The data containing batched model outputs and inputs. - kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. - kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". - - Returns: - tuple: A tuple containing: - - The updated data with token-level rewards adjusted by KL penalty - - A dictionary of metrics related to the KL penalty - """ - response_mask = data.batch["response_mask"] - token_level_scores = data.batch["token_level_scores"] - batch_size = data.batch.batch_size[0] - - # compute kl between ref_policy and current policy - # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. - kld = core_algos.kl_penalty( - data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty - ) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - - token_level_rewards = token_level_scores - beta * kld - - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch["token_level_rewards"] = token_level_rewards - - metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} - - return data, metrics - - -def compute_response_mask(data: DataProto): - """Compute the attention mask for the response part of the sequence. - - This function extracts the portion of the attention mask that corresponds to the model's response, - which is used for masking computations that should only apply to response tokens. - - Args: - data (DataProto): The data containing batched model outputs and inputs. - - Returns: - torch.Tensor: The attention mask for the response tokens. - """ - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - -def compute_grpo_outcome_advantage_new( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - task_index: np.ndarray, - rollout_index: np.ndarray, - epsilon: float = 1e-6, - norm_adv_by_std_in_grpo: bool = True, - config: Optional[AlgoConfig] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for GRPO, operating only on Outcome reward - (with only one scalar reward for each response). - - Args: - token_level_rewards: `(torch.Tensor)` - shape is (bs, response_length) - response_mask: `(torch.Tensor)` - shape is (bs, response_length) - task_index: `(np.ndarray)` - task_index array for grouping - epsilon: `(float)` - small value to avoid division by zero - norm_adv_by_std_in_grpo: `(bool)` - whether to scale the GRPO advantage - config: `(Optional[AlgoConfig])` - algorithm configuration object - - Note: - If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. - If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). - - Returns: - advantages: `(torch.Tensor)` - shape is (bs, response_length) - Returns: `(torch.Tensor)` - shape is (bs, response_length) - """ - scores = token_level_rewards.sum(dim=-1) # 1d-list - - id2score = defaultdict(list) - id2pointer = defaultdict(list) - id2mean = {} - id2std = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[task_index[i]].append(scores[i]) - id2pointer[task_index[i]].append(i) - - # compute mean and std - for idx in id2score: - # the score list - this_task_all_score = id2score[idx] - # get the rollout id list - this_task_all_rolloutid = [rollout_index[idx] for idx in id2pointer[idx]] - # for same rollout id sample, reduce mean - rolloutid2score = defaultdict(list) - rolloutid2meanscore = {} - for rolloutid, score in zip(this_task_all_rolloutid, this_task_all_score): - rolloutid2score[rolloutid].append(score) - for rolloutid in rolloutid2score: - rolloutid2meanscore[rolloutid] = torch.mean(torch.tensor(rolloutid2score[rolloutid])) - - this_task_all_score = list(rolloutid2meanscore.values()) - - if len(this_task_all_score) == 1: - # single sample for - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(this_task_all_score) > 1: - scores_tensor = torch.stack(this_task_all_score) - id2mean[idx] = torch.mean(scores_tensor) - id2std[idx] = torch.std(scores_tensor) - # if id2std[idx] < 0.01: - # id2std[idx] = 0.01 - else: - raise ValueError(f"no score in prompt task_index: {idx}") - - for i in range(bsz): - if norm_adv_by_std_in_grpo: - scores[i] = (scores[i] - id2mean[task_index[i]]) / (id2std[task_index[i]] + epsilon) - else: - scores[i] = scores[i] - id2mean[task_index[i]] - - scores = scores.unsqueeze(-1) * response_mask - - return scores, scores - -def compute_advantage( - data: DataProto, - adv_estimator: AdvantageEstimator, - gamma: float = 1.0, - lam: float = 1.0, - num_repeat: int = 1, - norm_adv_by_std_in_grpo: bool = True, - config: Optional[AlgoConfig] = None, -) -> DataProto: - """Compute advantage estimates for policy optimization. - - This function computes advantage estimates using various estimators like GAE, GRPO, REINFORCE++, etc. - The advantage estimates are used to guide policy optimization in RL algorithms. - - Args: - data (DataProto): The data containing batched model outputs and inputs. - adv_estimator (AdvantageEstimator): The advantage estimator to use (e.g., GAE, GRPO, REINFORCE++). - gamma (float, optional): Discount factor for future rewards. Defaults to 1.0. - lam (float, optional): Lambda parameter for GAE. Defaults to 1.0. - num_repeat (int, optional): Number of times to repeat the computation. Defaults to 1. - norm_adv_by_std_in_grpo (bool, optional): Whether to normalize advantages by standard deviation in - GRPO. Defaults to True. - config (dict, optional): Configuration dictionary for algorithm settings. Defaults to None. - - Returns: - DataProto: The updated data with computed advantages and returns. - """ - # Back-compatible with trainers that do not compute response mask in fit - if "response_mask" not in data.batch.keys(): - data.batch["response_mask"] = compute_response_mask(data) - # prepare response group - if adv_estimator == AdvantageEstimator.GAE: - # Compute advantages and returns using Generalized Advantage Estimation (GAE) - advantages, returns = core_algos.compute_gae_advantage_return( - token_level_rewards=data.batch["token_level_rewards"], - values=data.batch["values"], - response_mask=data.batch["response_mask"], - gamma=gamma, - lam=lam, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - if config.get("use_pf_ppo", False): - data = core_algos.compute_pf_ppo_reweight_data( - data, - config.pf_ppo.get("reweight_method"), - config.pf_ppo.get("weight_pow"), - ) - elif adv_estimator == AdvantageEstimator.GRPO: - # Initialize the mask for GRPO calculation - grpo_calculation_mask = data.batch["response_mask"] - # If multi-turn, replace the mask with the relevant part of loss_mask - # Get length from the initial response mask - response_length = grpo_calculation_mask.size(1) - # This mask is the one intended for GRPO - grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] - # Call compute_grpo_outcome_advantage with parameters matching its definition - if config.task_norm_patch: - advantages, returns = compute_grpo_outcome_advantage_new( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=grpo_calculation_mask, - task_index=data.non_tensor_batch["uid"], - rollout_index=data.non_tensor_batch["rollout_ids"], - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - else: - advantages, returns = core_algos.compute_grpo_outcome_advantage( - token_level_rewards=data.batch["token_level_rewards"], - response_mask=grpo_calculation_mask, - index=data.non_tensor_batch["uid"], - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - # handle all other adv estimator type other than GAE and GRPO - adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator) - adv_kwargs = { - "token_level_rewards": data.batch["token_level_rewards"], - "response_mask": data.batch["response_mask"], - "config": config, - } - if "uid" in data.non_tensor_batch: # optional - adv_kwargs["index"] = data.non_tensor_batch["uid"] - if "reward_baselines" in data.batch: # optional - adv_kwargs["reward_baselines"] = data.batch["reward_baselines"] - - # calculate advantage estimator - advantages, returns = adv_estimator_fn(**adv_kwargs) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - return data - - -class BeyondAgentRayPPOTrainer: - """Distributed PPO trainer using Ray for scalable reinforcement learning. - - This trainer orchestrates distributed PPO training across multiple nodes and GPUs, - managing actor rollouts, critic training, and reward computation with Ray backend. - Supports various model architectures including FSDP, Megatron, and vLLM integration. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - collate_fn=None, - train_sampler: Optional[Sampler] = None, - device_name=None, - ): - """ - Initialize distributed PPO trainer with Ray backend. - Note that this trainer runs on the driver process on a single CPU/GPU node. - - Args: - config: Configuration object containing training parameters. - tokenizer: Tokenizer used for encoding and decoding text. - role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. - resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. - ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. - processor: Optional data processor, used for multimodal data - reward_fn: Function for computing rewards during training. - val_reward_fn: Function for computing rewards during validation. - train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. - val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. - collate_fn: Function to collate data samples into batches. - train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. - device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. - """ - - # Store the tokenizer for text processing - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - self.device_name = device_name if device_name else self.config.trainer.device - self.validation_generations_logger = ValidationGenerationsLogger( - project_name=self.config.astune.project_name, - experiment_name=self.config.astune.experiment_name, - ) - - # if ref_in_actor is True, the reference policy will be actor without lora applied - self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 - - # define in-reward KL control - # kl loss control currently not suppoorted - if self.config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) - - if config.critic.enable is not None: - self.use_critic = bool(config.critic.enable) - elif self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - self.use_critic = True - else: - warnings.warn( - "Disabled critic as algorithm.adv_estimator != gae. " - "If it is not intended, please set critic.enable=True", - stacklevel=2, - ) - self.use_critic = False - - self._validate_config() - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - - def _validate_config(self): - config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - minimal_bsz = n_gpus - - # 1. Check total batch size for data correctness - real_train_batch_size = config.astune.data.train_batch_size * config.astune.rollout.num_repeat - assert real_train_batch_size % minimal_bsz == 0, ( - f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " - f"({minimal_bsz})" - ) - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". - def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - """Validate mutually exclusive micro batch size configuration options. - - Ensures that users don't set both deprecated micro_batch_size and - the new micro_batch_size_per_gpu parameters simultaneously. - - Args: - mbs: Deprecated micro batch size parameter value. - mbs_per_gpu: New micro batch size per GPU parameter value. - name (str): Configuration section name for error messages. - - Raises: - ValueError: If both parameters are set or neither is set. - """ - settings = { - "reward_model": "micro_batch_size", - "actor_rollout_ref.ref": "log_prob_micro_batch_size", - "actor_rollout_ref.rollout": "log_prob_micro_batch_size", - } - - if name in settings: - param = settings[name] - param_per_gpu = f"{param}_per_gpu" - - if mbs is None and mbs_per_gpu is None: - raise ValueError( - f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." - ) - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError( - f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " - f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." - ) - - # Actor validation done in ActorConfig.__post_init__ and validate() - actor_config = omega_conf_to_dataclass(config.actor_rollout_ref.actor) - actor_config.validate(n_gpus, config.astune.data.train_batch_size, config.actor_rollout_ref.model) - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - if self.use_reference_policy: - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref", - ) - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.astune.rollout.log_prob_micro_batch_size, - config.astune.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout", - ) - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" - ) - - if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: - print("NOTICE: You have both enabled in-reward kl and kl loss.") - - # critic - if self.use_critic: - critic_config = omega_conf_to_dataclass(config.critic) - critic_config.validate(n_gpus, config.astune.data.train_batch_size) - - if config.data.get("val_batch_size", None) is not None: - print( - "WARNING: val_batch_size is deprecated." - + " Validation datasets are sent to inference engines as a whole batch," - + " which will schedule the memory themselves." - ) - - # check eval config - if config.astune.rollout.val_kwargs.do_sample: - assert config.astune.rollout.temperature > 0, ( - "validation gen temperature should be greater than 0 when enabling do_sample" - ) - - print("[validate_config] All configuration checks passed successfully!") - - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): - """ - Creates the train and validation dataloaders. - """ - from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler - - if train_dataset is None: - train_dataset = create_rl_dataset( - self.config.data.train_files, self.config.data, self.tokenizer, self.processor - ) - if val_dataset is None: - val_dataset = create_rl_dataset( - self.config.data.val_files, self.config.data, self.tokenizer, self.processor - ) - self.train_dataset, self.val_dataset = train_dataset, val_dataset - - if train_sampler is None: - train_sampler = create_rl_sampler(self.config.data, self.train_dataset) - if collate_fn is None: - from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn - - collate_fn = default_collate_fn - - num_workers = self.config.data["dataloader_num_workers"] - - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.astune.data.train_batch_size), - num_workers=num_workers, - drop_last=True, - collate_fn=collate_fn, - sampler=train_sampler, - ) - - val_batch_size = self.config.data.val_batch_size # Prefer config value if set - if val_batch_size is None: - val_batch_size = len(self.val_dataset) - - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - batch_size=val_batch_size, - num_workers=num_workers, - shuffle=self.config.data.get("validation_shuffle", True), - drop_last=False, - collate_fn=collate_fn, - ) - - assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" - assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" - - print( - f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " - f"{len(self.val_dataloader)}" - ) - - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - try: - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - if OmegaConf.select(self.config, "critic.optim"): - self.config.critic.optim.total_training_steps = total_training_steps - except Exception as e: - print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - - def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): - """Dump rollout/validation samples as JSONL.""" - os.makedirs(dump_path, exist_ok=True) - filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") - - n = len(inputs) - base_data = { - "input": inputs, - "output": outputs, - "gts": gts, - "score": scores, - "step": [self.global_steps] * n, - } - - for k, v in reward_extra_infos_dict.items(): - if len(v) == n: - base_data[k] = v - - lines = [] - for i in range(n): - entry = {k: v[i] for k, v in base_data.items()} - lines.append(json.dumps(entry, ensure_ascii=False)) - - with open(filename, "w") as f: - f.write("\n".join(lines) + "\n") - - print(f"Dumped generations to {filename}") - - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.log_val_generations - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores, strict=True)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - - def _validate(self): - data_source_lst = [] - reward_extra_infos_dict: dict[str, list] = defaultdict(list) - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_gts = [] - sample_scores = [] - sample_turns = [] - - for test_data in self.val_dataloader: - test_batch = DataProto.from_single_dict(test_data) - - # repeat test batch - test_batch = test_batch.repeat( - repeat_times=self.config.astune.rollout.val_kwargs.n, interleave=True - ) - - # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": - return {} - - # Store original inputs - input_ids = test_batch.batch["input_ids"] - # TODO: Can we keep special tokens except for padding tokens? - input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] - # sample_inputs.extend(input_texts) - - ground_truths = [ - item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch - ] - sample_gts.extend(ground_truths) - - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - if "raw_prompt" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - if "interaction_kwargs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("interaction_kwargs") - if "agent_name" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("agent_name") - if "extras" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("extras") - - test_gen_batch = test_batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - - test_gen_batch.meta_info = { - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - "recompute_log_prob": False, - "do_sample": self.config.astune.rollout.val_kwargs.do_sample, - "validate": True, - "global_steps": self.global_steps, - } - print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") - - self.async_rollout_manager.wake_up() - main_val_dataset, test_normal_dataset, test_chanllenge_dataset = self.get_eval_dataset() - - print("=" * 10 + "start validate rollout" + "=" * 10) - trajectories, tasks, val_metrics = self.eval_dataset( - target_dataset=main_val_dataset, - target_dataset_name="main_val_dataset", - mode="validate", - epoch=f"test.1" - ) - print("=" * 10 + "end validate rollout" + "=" * 10) - test_output_gen_batch = self.parallel_env.to_dataproto(trajectories) - self.async_rollout_manager.sleep() - print("validation generation end") - - # Store generated outputs - output_ids = test_output_gen_batch.batch["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] - sample_outputs.extend(output_texts) - - test_batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object) - tasks = tasks[:len(main_val_dataset)] - test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch) - # test_batch = test_batch.union(test_output_gen_batch) - test_batch.meta_info["validate"] = True - - # evaluate using reward_function - if self.val_reward_fn is None: - raise ValueError("val_reward_fn must be provided for validation.") - result = self.val_reward_fn(test_batch, return_dict=True) - reward_tensor = result["reward_tensor"] - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_extra_infos_dict["reward"].extend(scores) - print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") - if "reward_extra_info" in result: - for key, lst in result["reward_extra_info"].items(): - reward_extra_infos_dict[key].extend(lst) - print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") - - # collect num_turns of each prompt - if "__num_turns__" in test_batch.non_tensor_batch: - sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) - - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) - break # hack to escape the loop after one batch - - sample_inputs = [m['messages'][0]['content'] for m in test_batch.non_tensor_batch['messages']] - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - - metric_dict = val_metrics - - return metric_dict - - def init_workers(self): - """Initialize distributed training workers using Ray backend. - - Creates: - 1. Ray resource pools from configuration - 2. Worker groups for each role (actor, critic, etc.) - """ - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role="actor_rollout", - profile_option=self.config.trainer.npu_profile.options, - ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cfg = omega_conf_to_dataclass(self.config.critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], - config=self.config.actor_rollout_ref, - role="ref", - profile_option=self.config.trainer.npu_profile.options, - ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. - # Instead, directly pass different resource pool to different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - wg_kwargs = {} # Setting up kwargs for RayWorkerGroup - if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: - wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout - if OmegaConf.select(self.config.trainer, "profile_steps") is not None: - wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") - assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( - "worker_nsight_options must be set when profile_steps is set" - ) - wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( - OmegaConf.select(self.config.trainer, "worker_nsight_options") - ) - wg_kwargs["device_name"] = self.device_name - - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, - ray_cls_with_init=worker_dict_cls, - **wg_kwargs, - ) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() - - if self.use_reference_policy and not self.ref_in_actor: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg["rm"] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor_rollout"] - self.actor_rollout_wg.init_model() - - # create async rollout manager and request scheduler - self.async_rollout_mode = False - if self.config.astune.rollout.mode == "async": - from verl.experimental.agent_loop.agent_loop import AsyncLLMServerManager - from verl.experimental.agent_loop.agent_loop import AgentLoopManager - self.async_rollout_mode = True - agent_loop_manager = AgentLoopManager( - config=self.config, - worker_group=self.actor_rollout_wg, - ) - self.async_server_list = agent_loop_manager.async_llm_servers - self.async_rollout_manager = AsyncLLMServerManager(self.config, self.async_server_list) - - self.reward_fn = parse_reward_from_dataproto - self.val_reward_fn = parse_reward_from_dataproto - from concurrent.futures import ThreadPoolExecutor - self.parallel_env = ParallelEnvManager(config=self.config, async_rollout_manager=self.async_rollout_manager, max_parallel=self.config.astune.rollout.max_env_worker, tokenizer=self.tokenizer) - - def _save_checkpoint(self): - from verl.utils.fs import local_mkdir_safe - - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") - ) - - remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) - if remove_previous_ckpt_in_save: - print( - "Warning: remove_previous_ckpt_in_save is deprecated," - + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" - ) - max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - - self.actor_rollout_wg.save_checkpoint( - actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") - ) - self.critic_wg.save_checkpoint( - critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep - ) - - # save dataloader - local_mkdir_safe(local_global_step_folder) - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") - - def _start_profiling(self, do_profile: bool) -> None: - """Start profiling for all worker groups if profiling is enabled.""" - if do_profile: - self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) - if self.use_reference_policy: - self.ref_policy_wg.start_profile() - if self.use_critic: - self.critic_wg.start_profile() - if self.use_rm: - self.rm_wg.start_profile() - - def _stop_profiling(self, do_profile: bool) -> None: - """Stop profiling for all worker groups if profiling is enabled.""" - if do_profile: - self.actor_rollout_wg.stop_profile() - if self.use_reference_policy: - self.ref_policy_wg.stop_profile() - if self.use_critic: - self.critic_wg.stop_profile() - if self.use_rm: - self.rm_wg.stop_profile() - - def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch["attention_mask"] - batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) - world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) - # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix - ) - metrics.update(global_balance_stats) - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC - to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.astune.project_name, - experiment_name=self.config.astune.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - self.tracking_logger = logger - self.global_steps = 0 - - - # load checkpoint before doing anything - self._load_checkpoint() - - # wake and sleep to enforce param sync - self.async_rollout_manager.wake_up() - self.async_rollout_manager.sleep() - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - assert val_metrics, f"{val_metrics=}" - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - self.max_steps_duration = 0 - - prev_step_profile = False - curr_step_profile = ( - self.global_steps in self.config.trainer.profile_steps - if self.config.trainer.profile_steps is not None - else False - ) - next_step_profile = False - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - - with marked_timer("start_profile", timing_raw): - self._start_profiling( - not prev_step_profile and curr_step_profile - if self.config.trainer.profile_continuous_steps - else curr_step_profile - ) - - # from vsdb import bp - # bp("YYY") - batch_dict['index'] = torch.tensor([i for i in range(len(batch_dict['task_id']))], dtype=torch.long) - - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # add uid to batch - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) - - # # pop those keys for generation - batch_keys_to_pop = ['index'] - non_tensor_batch_keys_to_pop = ['task_id', 'main_query', 'env_type', 'metadata', 'init_messages'] - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - - # pass global_steps to trace - gen_batch.meta_info["global_steps"] = self.global_steps - is_last_step = self.global_steps >= self.total_training_steps - - with marked_timer("step", timing_raw): - # generate a batch - print("=== + rollout step begin ===") - with marked_timer("gen", timing_raw, color="red"): - assert self.async_rollout_mode - print("=== wake up begin ===") - self.async_rollout_manager.wake_up() - print("=== wake up end ===") - # time.sleep(36000) - # from vsdb import bp - # bp("XXX") - tasks = [ - Task( - task_id=gen_batch.non_tensor_batch["task_id"][i], - main_query=gen_batch.non_tensor_batch["main_query"][i], - env_type=gen_batch.non_tensor_batch["env_type"][i], - metadata=gen_batch.non_tensor_batch["metadata"][i] - ) for i in range(len(gen_batch)) - ] - print([gen_batch.non_tensor_batch["task_id"][i] for i in range(len(gen_batch))]) - print("=" * 10 + "start fit rollout" + "=" * 10) - self.parallel_env.current_global_steps = self.global_steps - trajectories: List[CMTLinear] = self.parallel_env.rollout(tasks, mode="sample", epoch=f"train.{epoch}") - print("=" * 10 + "end fit rollout" + "=" * 10) - print("begin to convert trajectories to dataproto") - gen_batch_output = self.parallel_env.to_dataproto(trajectories) - print("end convertion") - # context_time_cost = [traj.context_time_cost for traj in trajectories] - # if context_time_cost: - # metrics.update({ - # "context_cost_avg": np.mean(context_time_cost), - # "context_cost_max": np.max(context_time_cost), - # "context_cost_min": np.min(context_time_cost), - # }) - success_rate = [traj.reward_structure.success_rate for traj in trajectories] - madness_rate = [traj.reward_structure.madness for traj in trajectories] - round_cnt = [traj.round_cnt for traj in trajectories] - metrics.update({ - "critic/round_cnt": np.mean(round_cnt), - "critic/madness_rate": np.mean(madness_rate), - "critic/success_rate": np.mean(success_rate), - "critic/real_success_rate": np.mean(trajectories[0].current_batch_success_rate), - }) - - print(f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}") - self.async_rollout_manager.sleep() - print("=== - rollout step end ===") - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - raise NotImplementedError("REMAX is not supported in GRPO yet.") - - batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) - batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output) - batch.batch["response_mask"] = compute_response_mask(batch) - - if "response_mask" not in batch.batch.keys(): - batch.batch["response_mask"] = compute_response_mask(batch) - # Balance the number of valid tokens across DP ranks. - # NOTE: This usually changes the order of data in the `batch`, - # which won't affect the advantage calculation (since it's based on uid), - # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - with marked_timer("reward", timing_raw, color="yellow"): - # compute reward model score - if self.use_rm: - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - if self.config.reward_model.launch_reward_fn_async: - raise NotImplementedError("launch_reward_fn_async is not supported in GRPO yet.") - else: - reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) - - # recompute old_log_probs - print("=== + compute log_probs begin ===") - with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_loss = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - assert not torch.isnan(entropy_loss).item(), "Entropy loss should not be NaN, something must have gone terribly wrong." - old_log_prob_metrics = {"actor/entropy": entropy_loss.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if "rollout_log_probs" in batch.batch.keys(): - # TODO: we may want to add diff of probs too. - from verl.utils.debug.metrics import calculate_debug_metrics - - metrics.update(calculate_debug_metrics(batch)) - - if self.use_reference_policy: - # compute reference log_prob - with marked_timer("ref", timing_raw, color="olive"): - if not self.ref_in_actor: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - else: - ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with marked_timer("values", timing_raw, color="cyan"): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with marked_timer("adv", timing_raw, color="brown"): - # we combine with rule-based rm - reward_extra_infos_dict: dict[str, list] - if self.config.reward_model.launch_reward_fn_async: - reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - batch.batch["token_level_scores"] = reward_tensor - - if reward_extra_infos_dict: - batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty( - batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty - ) - metrics.update(kl_metrics) - else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - - # compute advantages, executed on the driver process - - norm_adv_by_std_in_grpo = self.config.algorithm.get( - "norm_adv_by_std_in_grpo", True - ) # GRPO adv normalization factor - - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.astune.rollout.num_repeat, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - config=self.config.algorithm, - ) - - # update critic - if self.use_critic: - with marked_timer("update_critic", timing_raw, color="pink"): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with marked_timer("update_actor", timing_raw, color="red"): - batch.meta_info["multi_turn"] = self.config.astune.rollout.multi_turn.enable - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # Log rollout generations if enabled - rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) - if rollout_data_dir: - with marked_timer("dump_rollout_generations", timing_raw, color="green"): - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() - sample_gts = [ - item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) - for item in batch - ] - - if "request_id" in batch.non_tensor_batch: - reward_extra_infos_dict.setdefault( - "request_id", - batch.non_tensor_batch["request_id"].tolist(), - ) - - self._dump_generations( - inputs=inputs, - outputs=outputs, - gts=sample_gts, - scores=scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=rollout_data_dir, - ) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with marked_timer("testing", timing_raw, color="green"): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. - esi_close_to_expiration = should_save_ckpt_esi( - max_steps_duration=self.max_steps_duration, - redundant_time=self.config.trainer.esi_redundant_time, - ) - # Check if the conditions for saving a checkpoint are met. - # The conditions include a mandatory condition (1) and - # one of the following optional conditions (2/3/4): - # 1. The save frequency is set to a positive value. - # 2. It's the last training step. - # 3. The current step number is a multiple of the save frequency. - # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. - if self.config.trainer.save_freq > 0 and ( - is_last_step - or self.global_steps % self.config.trainer.save_freq == 0 - or esi_close_to_expiration - ): - if esi_close_to_expiration: - print("Force saving checkpoint: ESI instance expiration approaching.") - with marked_timer("save_checkpoint", timing_raw, color="green"): - self._save_checkpoint() - - with marked_timer("stop_profile", timing_raw): - next_step_profile = ( - self.global_steps + 1 in self.config.trainer.profile_steps - if self.config.trainer.profile_steps is not None - else False - ) - self._stop_profiling( - curr_step_profile and not next_step_profile - if self.config.trainer.profile_continuous_steps - else curr_step_profile - ) - prev_step_profile = curr_step_profile - curr_step_profile = next_step_profile - - steps_duration = timing_raw["step"] - self.max_steps_duration = max(self.max_steps_duration, steps_duration) - - # training metrics - metrics.update( - { - "training/global_step": self.global_steps, - "training/epoch": epoch, - } - ) - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - # TODO: implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) - - # this is experimental and may be changed/removed in the future in favor of a general-purpose one - if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): - self.train_dataloader.sampler.update(batch=batch) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - progress_bar.update(1) - self.global_steps += 1 - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - # this is experimental and may be changed/removed in the future - # in favor of a general-purpose data buffer pool - if hasattr(self.train_dataset, "on_batch_end"): - # The dataset may be changed after each training batch - self.train_dataset.on_batch_end(batch=batch) - - def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): - """ - Evaluate a dataset by running rollouts and computing task completion metrics. - - Args: - target_dataset: The dataset to evaluate - target_dataset_name: Name for logging purposes - mode: Evaluation mode ("sample" or "validate") - epoch: Current epoch for logging - - Returns: - Tuple of (cmts, tasks) containing trajectory results and task definitions - """ - pass_n = self.config.trainer.eval_pass_n - # if pass_n == 1: - # return self.eval_dataset_legacy(target_dataset, target_dataset_name, mode, epoch) - - tasks = [] - for _ in range(pass_n): - tasks += [ - task for task in target_dataset - ] - - cmts = self.parallel_env.rollout(tasks=tasks, mode=mode, epoch=epoch) # "sample" or "validate" - task_results = {} - for _cmt in cmts: - reward = _cmt.reward_structure.raw_reward - task_id = _cmt.task_id - if task_id not in task_results: - task_results[task_id] = {} - task_results[task_id]['reward_arr'] = [] - task_results[task_id]['tag_arr'] = [] - if reward >= 1: - _cmt.tag = "success" - elif reward == 0: - _cmt.tag = "failure" - else: - _cmt.tag = "half_success" - task_results[task_id]['tag_arr'] += [_cmt.tag] - task_results[task_id]['reward_arr'] += [_cmt.reward_structure.raw_reward] - task_results[task_id]['scenario'] = task_id.split('_')[0] - - task_scenario = [task_id.split('_')[0] for task_id in task_results.keys()] - set_scenarios = set(task_scenario) - num_scenarios = len(set_scenarios) - - repeated_success_tasks = 0 - num_all_success_tasks = 0 # n 次实验中全部success的任务数 - num_pass_n_tasks = 0 # n 次实验中至少有一次success的任务数 - for task_id, task_outcomes in task_results.items(): - # 计算 num_all_success_tasks # n 次实验中全部success的任务数 - # 计算 num_pass_n_tasks # n 次实验中至少有一次success的任务数 - assert len(task_outcomes['tag_arr']) == pass_n - if all(tag == "success" for tag in task_outcomes['tag_arr']): - num_all_success_tasks += 1 - if any(tag == "success" for tag in task_outcomes['tag_arr']): - num_pass_n_tasks += 1 - repeated_success_tasks += task_outcomes['tag_arr'].count("success") - - num_all_success_scenarios = 0 # 如果一个 scenario 的所有 task 都在 n 次实验中全部 success,则 num_all_success_scenarios +1 - num_pass_n_scenarios = 0 # 如果一个 scenario 的所有 task 都在 n 次实验中至少有一次 success,则 num_pass_n_scenarios +1 - repeated_num_pass_1_scenarios = 0 # 按顺序排列,如果一个 scenario 的所有 task 都在第 x 次实验中 success,则 repeated_num_pass_1_scenarios +1 - for scenario in set_scenarios: - scenario_task_results = {task_id: task_outcomes for task_id, task_outcomes in task_results.items() if task_outcomes['scenario'] == scenario} - # num_all_success_scenarios - if all(all(tag == "success" for tag in task_outcomes['tag_arr']) for task_outcomes in scenario_task_results.values()): - num_all_success_scenarios += 1 - # num_pass_n_scenarios - if all(any(tag == "success" for tag in task_outcomes['tag_arr']) for task_outcomes in scenario_task_results.values()): - num_pass_n_scenarios += 1 - # num_pass_1_scenarios - for x in range(pass_n): - if all(task_outcomes['tag_arr'][x]=='success' for task_outcomes in scenario_task_results.values()): - repeated_num_pass_1_scenarios += 1 - - # 记录日志 - task_scenario_for_cmts = [_cmt.task_id.split('_')[0] for _cmt in cmts] - for _cmt, scenario in zip(cmts, task_scenario_for_cmts): - task_outcome = _cmt.tag - selectors = [scenario, _cmt.task_id, task_outcome] - _cmt.generate_log() - reward = _cmt.reward_structure.raw_reward - - - rewards = [ _cmt.reward_structure.raw_reward for _cmt in cmts ] - num_tasks = len(task_results) - assert num_tasks == len(cmts) // pass_n - - val_metrics = { - "target dataset name": target_dataset_name, - "pass_n": pass_n, - - "total_tasks": len(task_results), - "num_all_success_tasks": num_all_success_tasks, - f"num_pass_n_tasks(pass@{pass_n})": num_pass_n_tasks, - - "num_scenarios": num_scenarios, - "num_all_success_scenarios": num_all_success_scenarios, - f"num_pass_n_scenarios(pass@{pass_n})": num_pass_n_scenarios, - - "TGC@1": repeated_success_tasks / (num_tasks * pass_n), - f"TGC@{pass_n}": num_pass_n_tasks / num_tasks, - f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks, - f"SGC@1": repeated_num_pass_1_scenarios / (num_scenarios * pass_n), - f"SGC@{pass_n}": num_pass_n_scenarios / num_scenarios, - f"SGC@{pass_n}-all-pass": num_all_success_scenarios / num_scenarios, - "mean_reward": sum(rewards) / len(rewards) if rewards else 0, - } - print_dict(val_metrics, narrow=True, header=target_dataset_name, mod="evaluation") - - self.tracking_logger.log(data=val_metrics, step=self.global_steps) - - return cmts, tasks, val_metrics - - - def get_eval_dataset(self): - from astune.utils.process_dataset import create_rl_dataset - if self.config.astune.task_reader.type == 'env_service': - if self.config.astune.task_reader.env_service.env_type == "appworld": - if hasattr(self, 'main_val_dataset'): - return self.main_val_dataset, None, None - else: - from astune.task_reader.task_reader_base import TaskReaderRouter - task_reader = TaskReaderRouter(self.config) - tasks = task_reader.get_validation_tasks() - self.main_val_dataset = tasks - return self.main_val_dataset, None, None - - # elif self.config.env_service.env_type == "webshop": - # if hasattr(self, 'main_val_dataset'): - # return self.main_val_dataset, None, None - # else: - # config = self.config - # self.main_val_dataset = create_rl_dataset(config.data.val_files, config.data, self.tokenizer, processor=None, is_train=False, env_config=config.env_service) - # # self.test_normal_dataset = create_rl_dataset(config.data.val_files, config.data, self.tokenizer, processor=None, is_train=False, env_config=config.env_service) - # if config.data.fast_eval: # 使用一个小测试集 - # self.main_val_dataset.dataframe = self.main_val_dataset.dataframe.shuffle(seed=42).select(range(100)) # limit to 100 samples - # return self.main_val_dataset, None, None - # else: - # self.main_val_dataset.dataframe = self.main_val_dataset.dataframe.shuffle(seed=42).select(range(500)) # limit to 100 samples - # return self.main_val_dataset, None, None - - # elif self.config.env_service.env_type == "crafters": - # if hasattr(self, 'main_val_dataset'): - # return self.main_val_dataset, None, None - # else: - # config = self.config - # self.main_val_dataset = create_rl_dataset(config.data.val_files, config.data, self.tokenizer, processor=None, is_train=False, env_config=config.env_service) - # # self.test_normal_dataset = create_rl_dataset(config.data.val_files, config.data, self.tokenizer, processor=None, is_train=False, env_config=config.env_service) - # self.main_val_dataset.dataframe = self.main_val_dataset.dataframe.shuffle(seed=42).select(range(10)) # limit to 100 samples - # return self.main_val_dataset, None, None - - else: - raise NotImplementedError \ No newline at end of file diff --git a/astune/context_manager/cmt_agentscope.py b/astune/context_manager/cmt_agentscope.py deleted file mode 100644 index 5bce7b1f..00000000 --- a/astune/context_manager/cmt_agentscope.py +++ /dev/null @@ -1,424 +0,0 @@ -import copy -import importlib -from loguru import logger -from datetime import datetime -from astune.schema.trajectory import Reward, Trajectory -from astune.context_manager.cmt_linear import CMTLinear, ExtendedMessage -from agentscope.model import DashScopeChatModel, ChatResponse -from astune.context_manager.cmt_linear import replace_token_ids, CMTLinear -from astune.schema.trajectory import Sample, Reward -from typing import Any, Dict, List, Union, Tuple -from beast_logger import register_logger, print_dict, print_nested, NestedJsonItem, SeqItem -from astune.utils.compute_madness import compute_string_madness -from agentscope._utils._common import _json_loads_with_repair, _create_tool_from_base_model -from astune.context_manager.cmt_base_attr import INVALID_LOG_PROB_VALUE - - -import colorsys - -def adjust_color_hsl(base_color, logprob): - """ - 使用HSL颜色空间根据logprob调整颜色饱和度 - """ - # 将logprob映射到[sat_min, sat_max]的饱和度调整因子 - sat_min = 0.333 - sat_max = 1.0 - lp_min = -7 - lp_max = 0 - - if logprob <= lp_min: - saturation_factor = sat_min - elif logprob >= 0: - saturation_factor = sat_max - else: - saturation_factor = sat_min + (logprob - lp_min) / (lp_max - lp_min) * (sat_max - sat_min) - - # 将十六进制颜色转换为RGB - r = int(base_color[1:3], 16) / 255.0 - g = int(base_color[3:5], 16) / 255.0 - b = int(base_color[5:7], 16) / 255.0 - - # 转换为HSL - h, l, s = colorsys.rgb_to_hls(r, g, b) - - # 调整饱和度 - s_adjusted = s * saturation_factor - - # 转换回RGB - r_adjusted, g_adjusted, b_adjusted = colorsys.hls_to_rgb(h, l, s_adjusted) - - # 转换回十六进制 - return f"#{int(r_adjusted*255):02x}{int(g_adjusted*255):02x}{int(b_adjusted*255):02x}" - - -class BeyondAgentContextTemplate(CMTLinear): - - def __init__(self, llm_chat_fn, tokenizer, config, env_step_fn, should_interrupt_fn, generated_token_callback_fn, **kwargs): - super().__init__(config, tokenizer) - self.task_batch_index = kwargs.pop("task_batch_index") - self.task_tag = kwargs.pop("task_tag") - self.task_id = kwargs.pop("task_id") - self.dscm_ref = DashScopeChatModel(**kwargs) - self.full_context: List[ExtendedMessage] = [] - self.llm_chat_fn = llm_chat_fn - self.tokenizer = tokenizer - self.stream = False - self.config = config - self.env_step_fn = env_step_fn - self.should_interrupt_fn = should_interrupt_fn - self.generated_token_callback_fn = generated_token_callback_fn - self.context_overflow = False - self.model_name = kwargs['model_name'] - self.output_kwargs = {} - self.input_kwargs = {} - - def process_reward(self, reward_structure: Reward): - self.reward_structure = reward_structure - ext_steps = self.full_context - # # lienar 模式只有一条轨迹 - # self.reward_structure.step_reward = [ - # self.compute_step_level_reward(ext_steps=ext_steps, index=0, total_steps=1) - # ] - # print('warning: debugging') - self.reward_structure.step_reward = [ - self.compute_step_level_reward(ext_steps=ext_steps, index=i, total_steps=len(self.grouped_steps)) for i in range(len(self.grouped_steps)) - ] - - - def generate_log(self, task_id = None, global_step="NA"): - task_id = self.task_id - nested_items_print_buffer = {} - for index, ext_steps in enumerate(self.grouped_steps): - from vsdb import bp - bp("LLL") - - cmt_tokenized = self.tokenize_steps(ext_steps=ext_steps, index=index, total_steps=len(self.grouped_steps)) - text_arr = [self.tokenizer.decode(t) for t in cmt_tokenized["input_ids"]] - input_id_arr = [str(t) for t in cmt_tokenized["input_ids"]] - # loss_mask_color_arr = ["#09ABCF" if mask==1 else "#D98510" for mask in cmt_tokenized["loss_mask"]] - logprobs = [INVALID_LOG_PROB_VALUE] * len(cmt_tokenized["prompt_ids"]) + cmt_tokenized["response_logprobs"] - # 创建调整后的颜色数组 - loss_mask_color_abl_arr = [ - adjust_color_hsl("#09ABCF", logprob) if mask == 1 - else adjust_color_hsl("#D98510", logprob) - for mask, logprob in zip(cmt_tokenized["loss_mask"], logprobs) - ] - logprob_text_arr = [f"{logprob:.4f}" if logprob != INVALID_LOG_PROB_VALUE else "N/A" for logprob in logprobs] - - buffer = { - "text_arr": text_arr, - "logprob_arr": logprob_text_arr, - "input_id_arr": input_id_arr, - "loss_mask_color_arr": loss_mask_color_abl_arr, - } - raw_reward = self.reward_structure.raw_reward - step_reward:float = self.reward_structure.step_reward[index] - try: - step_advantage = self.reward_structure.step_advantage[index] - step_advantage_simple = self.reward_structure.step_advantage_simple[index] - except: - step_advantage = 0.0 - step_advantage_simple = 0.0 - task_outcome = str(self.reward_structure.success_rate) - selectors = [task_id, task_outcome, str(index)] - len_prompt_ids = len(cmt_tokenized["prompt_ids"]) - len_response_ids = len(cmt_tokenized["response_ids"]) - len_input_ids = len(cmt_tokenized["input_ids"]) - assert len_prompt_ids + len_response_ids == len_input_ids, "len_prompt_ids + len_response_ids should equal to len_input_ids" - nested_items_print_buffer[f".".join(selectors)] = NestedJsonItem( - item_id=f"item", # type: ignore - outcome=task_outcome, # type: ignore - len_prompt_ids=len_prompt_ids, # type: ignore - len_response_ids=len_response_ids, # type: ignore - len_input_ids=len_input_ids, # type: ignore - raw_reward=f"{float(raw_reward):.3f}", # type: ignore - step_reward=f"{float(step_reward):.3f}", # type: ignore - step_advantage=f"{float(step_advantage):.3f}", # type: ignore - step_advantage_simple=f"{float(step_advantage_simple):.3f}", # type: ignore - content=SeqItem( - text = buffer['text_arr'], # 文本 - title = buffer['logprob_arr'], # 鼠标悬浮文本 - count = buffer['input_id_arr'], # 高亮文本 - color = buffer['loss_mask_color_arr'] # 颜色 - ) - ) - print_nested(nested_items_print_buffer, - main_content="This is the main content of the nested JSON", - header=f"[{global_step}] Task {task_id} (Reward {float(step_reward):.3f})", # type: ignore - mod="rollout", - narrow=False, - attach="copy this" # type: ignore - ) - - def group_merge(self): - def can_merge_steps(source_step: List[ExtendedMessage], target_step: List[ExtendedMessage]) -> bool: - # if `source_step` has more messages than `target_step` - # and if `source_step` and `target_step` share same token_arr in [0:len(target_step)] - # even if the authors are different, we can still merge them - can_merge = False - # compare_level = 'token' # 严格按照token对比 - compare_level = 'text' # 对比文本,这样子会导致有些token不一样但是文本一样的情况也能merge,更宽松一些,收益很大,代价未知 - if len(source_step) >= len(target_step): - all_msg_match = True - for i in range(len(target_step)): - if compare_level == 'text': - same = source_step[i].content_for_future == target_step[i].content_for_future - elif compare_level == 'token': - same = source_step[i].token_arr == target_step[i].token_arr - else: - raise NotImplementedError - if not same: - all_msg_match = False - break - if all_msg_match: - can_merge = True - return can_merge - - def toggle_author(source_step: List[ExtendedMessage], target_step: List[ExtendedMessage]) -> List[ExtendedMessage]: - # if any message in `target_step` is author == 'llm', but same-index message in `source_step` is author != 'llm' - # change source_step's message author to 'llm' - for i in range(len(target_step)): - if target_step[i].author == 'llm' and source_step[i].author != 'llm': - source_step[i].author = target_step[i].author - source_step[i].token_arr = target_step[i].token_arr - source_step[i].token_logprob_arr = target_step[i].token_logprob_arr - assert source_step[i].need_training - return source_step - - absorbed_step_indices = [] - reversed_grouped_steps = list(reversed(self.grouped_steps)) - for i in range(len(reversed_grouped_steps)): - if i in absorbed_step_indices: - continue - # check whether [i, len(reversed_grouped_steps)-1] can be merged - for j in range(i+1, len(reversed_grouped_steps)): - if j in absorbed_step_indices: - continue - source_step = reversed_grouped_steps[i] - target_step = reversed_grouped_steps[j] - if can_merge_steps(source_step, target_step): - source_step = toggle_author(source_step, target_step) - reversed_grouped_steps[i] = source_step - absorbed_step_indices += [j] - - # reverse back and exclude absorbed steps - reversed_grouped_steps_clean = [] - for i in range(len(reversed_grouped_steps)): - if i not in absorbed_step_indices: - reversed_grouped_steps_clean.append(reversed_grouped_steps[i]) - self.grouped_steps = list(reversed(reversed_grouped_steps_clean)) - - return self.grouped_steps - - def group_tokenize(self): - return self.group_tokenize_multi_group() - - def get_inc(self, text_frag_from, text_frag_to): - """ - Get the incremental token array from text_frag_from to text_frag_to. - """ - tokenizer_output = self.tokenizer(text_frag_from, return_tensors="pt", padding=False) - tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() - token_ids_acc = tokenizer_input_ids - - tokenizer_output = self.tokenizer(text_frag_to, return_tensors="pt", padding=False) - input_ids = tokenizer_output["input_ids"][0].tolist() - input_id_increment = input_ids[len(token_ids_acc):] # get the new tokens added in this step - overlap_length = 0 - for i in range(len(token_ids_acc)): - if i < len(token_ids_acc) and input_ids[i] == token_ids_acc[i]: overlap_length += 1 - else: break - msg = f"previous token length: {len(token_ids_acc)}, overlap token length: {(overlap_length)}, increment token length: {len(input_id_increment)}" - # print(msg) - return input_id_increment, msg - - def check_context_token_num_safe(self, messages: List[dict]) -> Tuple[bool, str]: - prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) - max_response_length = self.config.astune.rollout.max_response_length_in_one_turn - max_model_len: int = self.config.astune.rollout.max_model_len - self.max_seq_length: int = max_model_len - max_response_length - if self.should_interrupt_fn(): - return False, "externally_interrupted" - if self.already_mad_flag and self.config.astune.rollout.agent_madness_termination: - return False, "already_mad" - if length < self.max_seq_length: - return True, f"safe[{length} < {max_model_len} - {max_response_length}]" - else: - return False, "token_overflow" - - -class BeyondAgentLmProxy(BeyondAgentContextTemplate): - - async def execute_model_proxy(self, messages: List[dict], tools: List[dict]=[], tool_choice: str = "auto", **kwargs) -> dict: - # load messages into `self.full_context` - self.full_context = [] - - for i, msg in enumerate(messages): - if not isinstance(msg['content'], str): - continue - if msg['role'] not in ['user', 'assistant', 'system']: - continue - if msg['role'] == 'system': - author = 'initialization' - else: - # mask everything - author = 'env' - self.full_context += [ - ExtendedMessage( - author=author, - role=msg['role'], - content=msg['content'], - tokenizer=self.tokenizer, - token_generator="auto", - ) - ] - - # execute llm policy - messages = self.to_role_content(self.full_context) - # 4. ⚠️ check token overflow - is_safe, info = self.check_context_token_num_safe(messages) - custom_sampling_params = {} - if not is_safe: - logger.warning(f"[{info}] detected. Current token count exceeds the limit.") - self.context_overflow = True - return ChatResponse( - content = [{'type': 'text', 'text': 'beyondagent_proxy:[context_overflow]'}] - ) - - # from vsdb import bp - # bp("INF") - llm_output = self.llm_chat_fn(messages, custom_sampling_params) - - # compute_string_madness - if not self.already_mad_flag: - if compute_string_madness(completion=llm_output['content'], checklist=self.config.astune.rollout.compute_madness_checklist) < 0.0: - self.already_mad_flag = True - - # dummy response for now - token_generator = "manual" if 'tokens' in llm_output else "auto" - llm_ext_msg = ExtendedMessage( - author="llm", - role="assistant", - content=llm_output['content'], - token_generator=token_generator, - tokenizer=self.tokenizer, - ) - - from vsdb import bp - bp("LOG") - - if token_generator == "manual": - input_msg_ref = copy.deepcopy(messages) - token_arr_method2, token_logprob_arr = self.get_token_inc_from_vllm_response(input_msg_ref, llm_output) - assert len(token_arr_method2) <= self.config.astune.rollout.max_response_length_in_one_turn, f"Generated token length {len(token_arr_method2)} exceeds max_response_len {self.config.astune.rollout.max_response_length_in_one_turn}" - llm_ext_msg.token_arr = token_arr_method2 - llm_ext_msg.token_logprob_arr = token_logprob_arr - self.generated_token_callback_fn(llm_ext_msg.token_arr) - - # take snapshot of current timeline - if is_safe: - self.full_context += [ - llm_ext_msg - ] - prompt_text = self.tokenizer.apply_chat_template(self.to_role_content(self.full_context), tokenize=False, add_generation_prompt=True) - length = len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) - if length >= self.config.astune.rollout.max_model_len: - raise RuntimeError(f"Unexpected token overflow after adding LLM response. Full context length {length}, before gen info {info}, generated token length {len(llm_ext_msg.token_arr)}") - - self.grouped_steps += [copy.deepcopy(self.full_context)] - # return response - return ChatResponse( - content = [{'type': 'text', 'text': llm_ext_msg.content_for_future}] - ) - - - -class BeyondAgentProxy(BeyondAgentLmProxy): - """ - A proxy class that bridge: - - environment - - reward - - policy llm model - """ - - - async def __call__( - self, - messages: list[dict[str, Any]], - tools: list[dict] | None = None, - tool_choice = None, - structured_model = None, - **kwargs: Any, - ): - import dashscope - - # For qvq and qwen-vl models, the content field cannot be `None` or - # `[{"text": None}]`, so we need to convert it to an empty list. - if self.model_name.startswith("qvq") or "-vl" in self.model_name: - raise NotImplementedError("Not implemented for qvq and qwen-vl models yet.") - - kwargs = { - "messages": messages, - "model": self.model_name, - "stream": self.stream, - **self.dscm_ref.generate_kwargs, - **kwargs, - "result_format": "message", - # In agentscope, the `incremental_output` must be `True` when - # `self.stream` is True - "incremental_output": self.stream, - } - - if tools: - kwargs["tools"] = self.dscm_ref._format_tools_json_schemas(tools) - - if tool_choice: - self.dscm_ref._validate_tool_choice(tool_choice, tools) - kwargs["tool_choice"] = self.dscm_ref._format_tool_choice(tool_choice) - - if ( - self.dscm_ref.enable_thinking is not None - and "enable_thinking" not in kwargs - ): - kwargs["enable_thinking"] = self.dscm_ref.enable_thinking - - if structured_model: - if tools or tool_choice: - logger.warning( - "structured_model is provided. Both 'tools' and " - "'tool_choice' parameters will be overridden and " - "ignored. The model will only perform structured output " - "generation without calling any other tools.", - ) - format_tool = _create_tool_from_base_model(structured_model) - kwargs["tools"] = self.dscm_ref._format_tools_json_schemas( - [format_tool], - ) - kwargs["tool_choice"] = self.dscm_ref._format_tool_choice( - format_tool["function"]["name"], - ) - - response = await self.execute_model_proxy( - api_key=self.dscm_ref.api_key, - **kwargs, - ) - return response - - def update_agentscope_input_dictionary(self, **kwargs): - self.input_kwargs.update(kwargs) - - def get_agentscope_input_dictionary(self): - return self.input_kwargs - - def update_judge_input_dictionary(self, **kwargs): - self.output_kwargs.update(kwargs) - - def get_judge_input_dictionary(self): - return self.output_kwargs - - def get_judge(self): - judge_protocol = self.config.astune.task_judge.judge_protocol - module_, class_ = judge_protocol.split('->') - protocol_cls = getattr(importlib.import_module(module_), class_) - return protocol_cls(self.config) # type: ignore diff --git a/astune/context_manager/cmt_base_attr.py b/astune/context_manager/cmt_base_attr.py deleted file mode 100644 index 0ea8f3bd..00000000 --- a/astune/context_manager/cmt_base_attr.py +++ /dev/null @@ -1,182 +0,0 @@ -from typing import List, Union, Tuple -from astune.schema.trajectory import Reward -import uuid - - -def find_sublist_indices(large_list, small_list, reverse=False): - small_len = len(small_list) - if reverse: - for i in reversed(range(len(large_list) - small_len + 1)): - if large_list[i: i+small_len] == small_list: - return i - for i in range(len(large_list) - small_len + 1): - if large_list[i: i+small_len] == small_list: - return i - return -1 - -INVALID_LOG_PROB_VALUE = 0.0 - -def replace_token_ids(place_holder, replace_with, begin, end, raw_logprob) -> Tuple[List[int], List[int]]: - _begin_index = find_sublist_indices(place_holder, begin) + len(begin) - _end_index = find_sublist_indices(place_holder, end, reverse=True) - - if replace_with[-len(end):] == end: # remove end token - replace_with = replace_with[:-len(end)] - raw_logprob = raw_logprob[:-len(end)] - if replace_with[:len(begin)] == begin: # remove begin token - replace_with = replace_with[len(begin):] - raw_logprob = raw_logprob[len(begin):] - - final = place_holder[:_begin_index] + replace_with + place_holder[_end_index:] - final_logprob = [INVALID_LOG_PROB_VALUE] * _begin_index + raw_logprob + [INVALID_LOG_PROB_VALUE] * (len(place_holder) - _end_index) - return final, final_logprob - -class ExtendedMessage: - - def __init__( - self, - author, - role="assistant", - content="", - token_arr=[], - token_begin_index=-1, - token_end_index=-1, - clip=False, - clip_token_limit=8192, - tokenizer=None, - token_generator="manual", - build_from_uuid="", - token_logprob_arr=[], - ): - self.author = author - self.role = role - self.content = content - self.token_arr = token_arr - self.token_logprob_arr = token_logprob_arr - self.token_begin_index = token_begin_index - self.token_end_index = token_end_index - self.invalid_log_prob_value = INVALID_LOG_PROB_VALUE - # use property to ensure content is safe before use - self._content_for_future = "" - self._info = "" - self.clip = clip - self.uuid = uuid.uuid4().hex - self.build_from_uuid = build_from_uuid - - if not clip: - self.generate_content_for_future(tokenizer=None, clip=False) - else: - self.generate_content_for_future(tokenizer=tokenizer, clip=True, clip_token_limit=clip_token_limit) - self.eos_token_id = tokenizer.eos_token_id - if token_generator == 'auto': - dummy_msg = [ {"role": "assistant", "content": "dummy text"} ] - try: - text_frag_to = tokenizer.apply_chat_template(dummy_msg + [ {"role": self.role, "content": self.content_for_future} ], tokenize=False) - except Exception as e: - raise ValueError(f"Cannot tokenize {self.role} --- {self.content_for_future}, \n\n Error: {e}") - self.token_arr, _ = self.get_inc_simple( - text_frag_from=tokenizer.apply_chat_template(dummy_msg, tokenize=False), - text_frag_to=text_frag_to, - tokenizer=tokenizer - ) - - @property - def content_for_future(self): - if self._content_for_future == "": raise ValueError("content_for_future is not set, or previous llm output is empty!") - return self._content_for_future - - - @property - def need_training(self): - NEED_TRAIN_AUTHORS = ["llm"] - NON_TRAIN_AUTHORS = ["env", "initialization", "user", "memory", "llm(do_not_train)"] - assert (self.author in NEED_TRAIN_AUTHORS) or (self.author in NON_TRAIN_AUTHORS) or (self.author.endswith('(discard)')), f"author {self.author} is not identified" - return (self.author in NEED_TRAIN_AUTHORS) - - - def generate_content_for_future(self, tokenizer, clip, clip_token_limit=-1): - _content: str = self.content - if clip: - assert clip_token_limit > 0, "clip_token_limit must be set when clip is True" - n_token = len(tokenizer(_content, return_tensors="pt", padding=False)["input_ids"][0]) - if n_token > clip_token_limit: - # 8000 > 4000 - n_char = len(_content) # 10,000 - eps = 100 # token - preserve_percent = (clip_token_limit - eps) / n_token # 3900 / 8000 - n_char_to_preserve = int(n_char * preserve_percent) - _content = _content[:n_char_to_preserve] + "... truncate ..." - self._content_for_future = _content - - - def get_loss_mask(self, blackout_token_combo): - def blackout_specific_token_ids_first_encounter(mask, arr, token_ids): - index = find_sublist_indices(arr, token_ids, reverse=False) - if index >= 0: - for i in range(index, index+len(token_ids)): mask[i] = 0 - return mask - - def blackout_everything_after_eos_but_keep_eos(mask, token_arr, eos_token_id): - eos_position = token_arr.index(eos_token_id) if eos_token_id in token_arr else -1 - if eos_position != -1: - for i in range(eos_position + 1, len(mask)): - mask[i] = 0 - return mask - - if self.need_training: - msg_token_mask = [1] * len(self.token_arr) - msg_token_mask = blackout_specific_token_ids_first_encounter(msg_token_mask, self.token_arr, blackout_token_combo) - msg_token_mask = blackout_everything_after_eos_but_keep_eos(mask=msg_token_mask, token_arr=self.token_arr, eos_token_id=self.eos_token_id) - return msg_token_mask - else: - msg_token_mask = [0] * len(self.token_arr) - return msg_token_mask - - def get_inc_simple(self, text_frag_from, text_frag_to, tokenizer): - """ - Get the incremental token array from text_frag_from to text_frag_to. - """ - tokenizer_output = tokenizer(text_frag_from, return_tensors="pt", padding=False) - tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() - token_ids_acc = tokenizer_input_ids - - tokenizer_output = tokenizer(text_frag_to, return_tensors="pt", padding=False) - input_ids = tokenizer_output["input_ids"][0].tolist() - input_id_increment = input_ids[len(token_ids_acc):] # get the new tokens added in this step - overlap_length = 0 - for i in range(len(token_ids_acc)): - if i < len(token_ids_acc) and input_ids[i] == token_ids_acc[i]: overlap_length += 1 - else: break - msg = f"previous token length: {len(token_ids_acc)}, overlap token length: {(overlap_length)}, increment token length: {len(input_id_increment)}" - # print(msg) - return input_id_increment, msg - -class CMTBaseAttr(object): - - def __init__(self, config, tokenizer): - self.task_batch_index = 'undefined' - self.task_tag = 'undefined' - self.task_id = 'undefined' - self.config = config - self.tokenizer = tokenizer - self.full_context: List[ExtendedMessage] = [] - self.grouped_steps: List[List[ExtendedMessage]] = [] - self.current_context_status = "" - max_response_length = self.config.astune.rollout.max_response_length_in_one_turn - max_model_len: int = self.config.astune.rollout.max_model_len - self.max_seq_length: int = max_model_len - max_response_length - self.max_env_output_length: int = self.config.astune.rollout.max_env_len - self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") - self.generated_token_cnt = 0 - - self.terminal_rewards_dict = {} - self.discarded = False - self.is_terminated = False - self.reward_structure: Union[Reward, None] = None - self.context_time_cost = 0 - self.tag = "" - self.current_batch_success_rate:float = -1.0 - self.already_mad_flag = False - self.round_cnt = 0 - - assert self.config.astune.data.max_prompt_length + self.config.astune.data.max_response_length <= max_model_len \ No newline at end of file diff --git a/astune/context_manager/cmt_context_aware.py b/astune/context_manager/cmt_context_aware.py deleted file mode 100644 index a2da0825..00000000 --- a/astune/context_manager/cmt_context_aware.py +++ /dev/null @@ -1,171 +0,0 @@ -import copy -from typing import List -from beast_logger import print_listofdict -from astune.context_manager.cmt_linear_think import ExtendedMessage, CMTLinear, LinearThinkCMT -from beast_logger import print_dict, print_nested, NestedJsonItem, SeqItem - - -class SelfContextAwareCMT(LinearThinkCMT): - """ - A non-linear context manager template that handles the conversation flow between LLM and environment. - """ - - def __init__(self, config, tokenizer, llm_chat_fn): - self.llm_chat_fn = llm_chat_fn - self.latest_env_response_id = "" - self.latest_env_response_content = "" - self.console_debug_mode = False - self.force_think = config.astune.rollout.force_think - super().__init__(config, tokenizer) - - - def post_tag_env_message_context(self, content, turn, is_last) -> str: - from textwrap import dedent - assert 0 <= turn < 999, "turn 必须在 [0, 999) 范围内" - turn_id = f"{turn:03d}" # 等效:str(turn).zfill(3) - self.latest_env_response_id = f"ER{turn_id}" - self.latest_env_response_content = content.strip() - content = dedent(f""" - [Environment Response, id="ER{turn_id}"] - --- - """).strip() + content.strip() - if is_last and self.force_think: - content += "\n\nAdditional requirements: \n- You must think step by step before your next action, and you must use ... to wrap your thinking process before finally produce your answer with \\box{}. (Put \\box{} outside ...)." - - return content - - def post_tag_init_message_context(self, content, is_last) -> str: - if is_last: - content = content.strip() # + "\nSome additional requirements for last msg \n" - if is_last and self.force_think: - content += "\n\nAdditional requirements: \n- You must think step by step before your next action, and you must use ... to wrap your thinking process before finally produce your answer with \\box{}. (Put \\box{} outside ...)." - return content.strip() - - def prepare_next_llm_context(self): - self.latest_llm_interaction_socket = [] - - # first we get all previous context (non-deprecated context) - # get `init_message -> user -> llm -> user -> llm`` or `init_message -> llm -> user -> llm -> user`` - self.latest_llm_interaction_socket = self.filter_context_via_authors(["initialization", "llm", "env"]) - - - env_turn = 1 - for index, ext_msg in enumerate(list(self.latest_llm_interaction_socket)): - - is_last = (index == len(self.latest_llm_interaction_socket) - 1) - # 根据消息类型进行处理 - if ext_msg.author == "llm": - # 如果是以往的llm消息,去掉think标签 - import re - new_ext_msg_content = re.sub(r'.*?', '', ext_msg.content, flags=re.DOTALL).strip() - new_ext_msg_content = new_ext_msg_content.replace("", "") - new_ext_msg_content = new_ext_msg_content.replace("", "") - - assert ext_msg.author == "llm" - author_override = "llm(do_not_train)" - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=author_override, - role=ext_msg.role, - content=new_ext_msg_content, - token_generator='auto', - tokenizer=self.tokenizer, - ) - - # process env message - elif ext_msg.author == "env": - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=self.post_tag_env_message_context(content=ext_msg.content_for_future, turn=env_turn, is_last=is_last), - token_generator='auto', - tokenizer=self.tokenizer, - ) - env_turn += 1 - - elif ext_msg.author in ["initialization"]: - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=self.post_tag_init_message_context(content=ext_msg.content_for_future, is_last=is_last), - token_generator='auto', - tokenizer=self.tokenizer, - ) - - else: - raise RuntimeError(f"Unknown author {ext_msg.author} in latest_llm_interaction_socket") - - dict_context = self.to_role_content(self.latest_llm_interaction_socket) - return dict_context - - - def save_init_input(self, init_input_arr:list, add_nothink): - super().save_init_input(init_input_arr, add_nothink) - return - - - def after_save_llm_output(self, llm_output, this_interaction): - if not self.latest_env_response_id: - return - self.latest_llm_interaction_socket_additional = copy.deepcopy(this_interaction) - self.latest_llm_interaction_socket_additional += [ExtendedMessage( - author='user', - role='user', - content=f"""Now your new task is to inspect `Environment Response` {self.latest_env_response_id} and then extract paragraphs that may be useful information in last action or in the future.""" - """For example, if the original Response contain paragraph ABCDEF and only paragraph ABCF maybe useful, you should answer me by copying paragraph ABCF (wrapped them between ```).""" - """Do not give up details easily, try your best to find useful information. When necessary, you can preserve everything.""", - token_generator='auto', - tokenizer=self.tokenizer, - )] - dict_context = self.to_role_content(self.latest_llm_interaction_socket_additional) - llm_output = self.llm_chat_fn(dict_context, request_id="") - self.latest_llm_interaction_socket_additional += [self.save_llm_output_do_not_register_full_context(llm_output, dict_context)] - this_interaction = copy.deepcopy(self.latest_llm_interaction_socket_additional) - self.grouped_steps += [this_interaction] - - - if self.console_debug_mode: - print_listofdict( - dict_context + - [{'role': 'llm_latest', 'content': llm_output['content']}] - , mod='c') - try: - llm_output_content = llm_output['content'] = llm_output['content'].strip() - if llm_output_content.count("```") == 2: - extracted_content: str = llm_output_content.split("```")[1].strip() - else: - raise RuntimeError(f"Cannot find ``` in llm_output content: {llm_output_content}") - - # override future full_context - assert self.latest_env_response_content != '' - replace_success = self.replace_full_context_item(match_content=self.latest_env_response_content, new_content=extracted_content) - if not replace_success: - raise RuntimeError(f"Cannot find {self.latest_env_response_id} in full_context") - - except Exception as e: - print(f"Error processing llm_output") - return - - def replace_full_context_item(self, match_content: str, new_content: str): - success = False - for index in range(len(self.full_context)): - ext_msg = self.full_context[index] - if match_content in ext_msg.content_for_future: - success = True - self.full_context[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=new_content, - token_generator='auto', - tokenizer=self.tokenizer, - ) - # print_dict({match_content: new_content}) - return success - return success - - def save_llm_output(self, llm_output, input_msg_ref): - ext_msg = CMTLinear.save_llm_output(self, llm_output, input_msg_ref) - this_interaction = copy.deepcopy(self.latest_llm_interaction_socket + [ext_msg]) - self.grouped_steps += [this_interaction] - self.after_save_llm_output(llm_output, this_interaction) - self.latest_llm_interaction_socket = [] - return diff --git a/astune/context_manager/cmt_context_clip.py b/astune/context_manager/cmt_context_clip.py deleted file mode 100644 index 4b01d427..00000000 --- a/astune/context_manager/cmt_context_clip.py +++ /dev/null @@ -1,365 +0,0 @@ -import copy -import re -import json -import random -import time -from typing import List, Callable -from beast_logger import print_dict, print_listofdict -from astune.context_manager.cmt_linear_think import ExtendedMessage, CMTLinear, LinearThinkCMT -from astune.context_manager.cmt_foreign_llm import construct_alien_llm_chat_fn -from textwrap import dedent -from openai import OpenAI -from loguru import logger - - -class SelfContextClipCMT(LinearThinkCMT): - """ - A non-linear context manager template that handles the conversation flow between LLM and environment. - """ - - def __init__(self, config, tokenizer, llm_chat_fn): - self.llm_chat_fn = llm_chat_fn - self.alien_llm_chat_fn: Callable = construct_alien_llm_chat_fn(config, config.actor_rollout_ref.rollout) - self.latest_env_response_id = "" - self.latest_env_response_content = "" - self.console_debug_mode = False - self.force_think = config.astune.rollout.force_think - self.env_action_preference = config.astune.task_reader.env_service.env_action_preference - self.train_sp_action = config.astune.context_manager.auto_context_cm.train_sp_action - self.clipped_before = False - if self.env_action_preference == "box": - self.force_think_prompt = dedent(""" - Additional requirements: Think before action! You must think step by step before your next action, and you must use ... to wrap your thinking process before finally produce your answer with \\box{}. - Your thought (...) should be as short and concise as possible. - For example: - ...your thinking process... - \\box{...your final answer...} - """) - elif self.env_action_preference == "code": - self.force_think_prompt = dedent(""" - Additional requirements: Think before action! You must think step by step before your next action, and you must use ... to wrap your thinking process before finally produce the next-step action. - Your thought (...) should be as short and concise as possible. - For example: - ...your thinking process... - ```python - # your action here - ``` - """) - - super().__init__(config, tokenizer) - - def post_tag_init_message_context(self, content, is_last) -> str: - if is_last: - content = content.strip() # + "\nSome additional requirements for last msg \n" - if is_last and self.force_think: - content += self.force_think_prompt - return content.strip() - - def post_tag_env_message_context(self, content, turn, is_last) -> str: - from textwrap import dedent - assert 0 <= turn < 999, "turn 必须在 [0, 999) 范围内" - turn_id = f"{turn:03d}" - self.latest_env_response_id = f"ER{turn_id}" - self.latest_env_response_content = content.strip() - content = dedent(f""" - [Environment Response, id=ER{turn_id}] - --- - """).strip() + '\n' + content.strip() - if is_last and self.force_think: - content += self.force_think_prompt - return content - - def post_tag_llm_message_context(self, content, turn, is_last) -> str: - from textwrap import dedent - assert not is_last, "llm message should never be last" - assert 0 <= turn < 999, "turn 必须在 [0, 999) 范围内" - turn_id = f"{turn:03d}" - content = dedent(f""" - [Assistant Response, id=AR{turn_id}] - --- - """).strip() + '\n' + content.strip() - return content - - def strip_think_tags(self, text: str) -> str: - new_ext_msg_content = re.sub(r'.*?', '', text, flags=re.DOTALL).strip() - new_ext_msg_content = new_ext_msg_content.replace("", "") - new_ext_msg_content = new_ext_msg_content.replace("", "") - return new_ext_msg_content - - def prepare_next_llm_context(self): - self.latest_llm_interaction_socket = [] - - # first we get all previous context (non-deprecated context) - # get `init_message -> user -> llm -> user -> llm`` or `init_message -> llm -> user -> llm -> user`` - self.latest_llm_interaction_socket = self.filter_context_via_authors(["initialization", "llm", "env"]) - - env_turn = 1 - llm_turn = 1 - for index, ext_msg in enumerate(list(self.latest_llm_interaction_socket)): - is_last = (index == len(self.latest_llm_interaction_socket) - 1) - # 根据消息类型进行处理 - if ext_msg.author == "llm": - # 如果是以往的llm消息,去掉think标签 - new_ext_msg_content = self.strip_think_tags(ext_msg.content) - author_override = "llm(do_not_train)" - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=author_override, - role=ext_msg.role, - content=self.post_tag_llm_message_context(new_ext_msg_content, turn=llm_turn, is_last=is_last), - token_generator='auto', - tokenizer=self.tokenizer, - build_from_uuid=ext_msg.uuid, - ) - llm_turn += 1 - - # process env message - elif ext_msg.author == "env": - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=self.post_tag_env_message_context(content=ext_msg.content_for_future, turn=env_turn, is_last=is_last), - token_generator='auto', - tokenizer=self.tokenizer, - build_from_uuid=ext_msg.uuid, - ) - env_turn += 1 - - elif ext_msg.author in ["initialization"]: - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=self.post_tag_init_message_context(content=ext_msg.content_for_future, is_last=is_last), - token_generator='auto', - tokenizer=self.tokenizer, - build_from_uuid=ext_msg.uuid, - ) - - else: - raise RuntimeError(f"Unknown author {ext_msg.author} in latest_llm_interaction_socket") - - listofdict_context = self.to_role_content(self.latest_llm_interaction_socket) - return listofdict_context - - - def save_init_input(self, init_input_arr:list, add_nothink): - super().save_init_input(init_input_arr, add_nothink) - return - - - def impl_new_request_from_previous_interaction(self, new_message, this_interaction, strip_think=False): - latest_llm_interaction_socket_additional = copy.deepcopy(this_interaction) - if strip_think: - for index, ext_msg in enumerate(latest_llm_interaction_socket_additional): - if ext_msg.author == "llm(do_not_train)" or ext_msg.author == "llm": - latest_llm_interaction_socket_additional[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=self.strip_think_tags(ext_msg.content), - token_generator='auto', - tokenizer=self.tokenizer, - build_from_uuid=ext_msg.build_from_uuid if ext_msg.build_from_uuid else ext_msg.uuid, - ) - else: - continue - latest_llm_interaction_socket_additional += [new_message] - dict_context = self.to_role_content(latest_llm_interaction_socket_additional) - if self.train_sp_action: - llm_output = self.llm_chat_fn(dict_context, request_id="") - else: - llm_output = self.alien_llm_chat_fn(dict_context, request_id="") - latest_llm_interaction_socket_additional += [self.save_llm_output_do_not_register_full_context(llm_output, dict_context)] - if self.train_sp_action: - this_interaction = copy.deepcopy(latest_llm_interaction_socket_additional) - self.grouped_steps += [this_interaction] - if self.console_debug_mode: - print_listofdict( - dict_context + [{'role': 'llm_latest', 'content': llm_output['content']}], mod='c' - ) # print to console - else: - print_listofdict( - dict_context + [{'role': 'llm_latest', 'content': llm_output['content']}], mod='env_clip' - ) # log to file - output_llm_content = llm_output['content'].strip() - return latest_llm_interaction_socket_additional, output_llm_content - - - def after_save_llm_output(self, this_interaction): - """ - this_interaction = [ - init msg, - ..., - init msg, - ... - previous env msg, - latest llm msg, - ] - """ - from textwrap import dedent - if not self.latest_env_response_id: - return - - clip_token_cnt = self.config.astune.context_manager.auto_context_cm.token_num_trigger_clip - this_interaction = copy.deepcopy(this_interaction) - if self._get_seq_length(this_interaction) < clip_token_cnt: - return - - if self.clipped_before: - return - self.clipped_before = True - - _, generated_content = self.impl_new_request_from_previous_interaction( - new_message=ExtendedMessage( - author='user', - role='user', - content=dedent(""" - Your new task is to inspect each `Environment Response` and `Assistant Response` messages, - and determine whether each message is useful for the next-step decision-making. - Generate a json structure following the format below: - ```json - [ - {"id":"ARXXX or ERXXX", "useful":true or false, "action": "keep or remove or compress"}, - ..., - {"id":"ARXXX or ERXXX", "useful":true or false, "action": "keep or remove or compress"}, - ] - ``` - - For example: - ```json - [ - {"id":"ER001", "useful":true, "action": "keep"}, - {"id":"AR001", "useful":false, "action": "remove"}, - ... - ] - ``` - - Rules: - - If the message contains useful information for future decisions, set "useful":true and "action":"keep". - - If the message records important previous action or environment feedback, set "useful":true and "action":"keep". - - If the message is very long and very redundant, set "useful":true and "action":"compress". - - If the message is completely irrelevant, set "useful":false and "action":"remove". Note that important failures should be preserved, because learning from past is vital. - - Ignore messages without id=XXX tags, where XXX is a 3-digit number. - - Ensure the JSON is properly formatted and valid. - - Remove or compress at least one message, because token limit is already reached. - - At least remove (or compress) one message. - - There must be no more than 2 "compress" actions in total, because "compress" action will cost considerable amount of time. - - """), - token_generator='auto', - tokenizer=self.tokenizer, - ), - this_interaction=this_interaction, - strip_think=True, - ) - - try: - llm_output_content = generated_content = generated_content.strip() - if llm_output_content.count("```") == 2: - extracted_content: str = llm_output_content.split("```")[1].strip() - else: - raise RuntimeError(f"Cannot find ``` in llm_output content: {llm_output_content}") - if extracted_content.startswith('json'): - extracted_content = extracted_content[len('json'):].strip() - extracted_json = json.loads(extracted_content) - for item in extracted_json: - if 'id' not in item or 'useful' not in item or 'action' not in item: - raise RuntimeError(f"Each item must contain 'id', 'useful', and 'action' fields. Error in item: {item}") - message_id = item['id'] - message_action = item['action'] - # find message from self.full_context - ## match latest_llm_interaction_socket_additional and match self.full_context - from_uuid = None - for ext_msg in this_interaction: - if message_id in ext_msg.content_for_future: - from_uuid = ext_msg.build_from_uuid - break - if from_uuid is None: - raise ValueError(f"Cannot find message_id {message_id} in `this_interaction`") - target_msg = None - target_index = -1 - for index, msg in enumerate(self.full_context): - if msg.uuid == from_uuid: - target_msg = msg - target_index = index - break - if target_msg is None or target_index == -1: - raise ValueError(f"Cannot find message_id {message_id} in full_context") - - ## take actions - if message_action == 'remove': - self.full_context[target_index] = ExtendedMessage( - author=target_msg.author+"(discard)", - role=target_msg.role, - content=target_msg.content, # keep original content - token_generator='auto', - tokenizer=self.tokenizer, - ) - elif message_action == 'compress': - target_id = message_id - _, generated_compressed_content = self.impl_new_request_from_previous_interaction( - new_message=ExtendedMessage( - author='user', - role='user', - content=dedent(f""" - Your new task is to inspect {target_id}, and filter out all redundant information, and only keep the most important information that is useful for future decision-making. - For example, if the content is a long text with multiple paragraphs, you should only preseve the key paragraphs and use ... to replace the rest. - If the content is a long list of data / dict / json, you should only preseve the key items and use ... to replace the rest. - Be careful to preserve all information that might be useful in the future. You should at least reduce 50% of {target_id}. - Remember: wrap your answer with ``` - - Your response should be like: - ``` - ...content after filtering... - ``` - """), - token_generator='auto', - tokenizer=self.tokenizer, - ), - this_interaction=this_interaction[:-1], # exclude the latest llm message - strip_think=True, - ) - if generated_compressed_content.count("```") != 2: - raise RuntimeError(f"Cannot find ``` in llm_output content: {generated_compressed_content}") - compressed_content = generated_compressed_content.split("```")[1].strip() - self.full_context[target_index] = ExtendedMessage( - author=target_msg.author, - role=target_msg.role, - content=compressed_content, - token_generator='auto', - tokenizer=self.tokenizer, - ) - elif message_action == 'keep': - continue - else: - raise RuntimeError(f"Unknown action {message_action}, must be one of ['remove', 'keep', 'compress']") - - except Exception as e: - logger.bind(exception=True).exception(f"Error processing llm_output: {e}") - print(f"Error processing llm_output") - return - - - def replace_full_context_item(self, match_content: str, new_content: str): - success = False - for index in range(len(self.full_context)): - ext_msg = self.full_context[index] - if match_content in ext_msg.content_for_future: - success = True - self.full_context[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=new_content, - token_generator='auto', - tokenizer=self.tokenizer, - ) - # print_dict({match_content: new_content}) - return success - return success - - - def save_llm_output(self, llm_output, input_msg_ref): - ext_msg = CMTLinear.save_llm_output(self, llm_output, input_msg_ref) - this_interaction = copy.deepcopy(self.latest_llm_interaction_socket + [ext_msg]) - self.grouped_steps += [this_interaction] - self.after_save_llm_output(this_interaction) - self.latest_llm_interaction_socket = [] - return diff --git a/astune/context_manager/cmt_linear.py b/astune/context_manager/cmt_linear.py deleted file mode 100644 index 34d24c1e..00000000 --- a/astune/context_manager/cmt_linear.py +++ /dev/null @@ -1,653 +0,0 @@ -import torch -import copy -from loguru import logger -from collections import defaultdict -from typing import List, Union, Tuple, Optional -from astune.schema.trajectory import Sample, Reward -from astune.utils.compute_madness import compute_string_madness -from astune.context_manager.cmt_base_attr import CMTBaseAttr -from astune.context_manager.cmt_base_attr import ExtendedMessage -from astune.context_manager.cmt_base_attr import replace_token_ids -from beast_logger import register_logger, print_dict, print_listofdict, print_nested, NestedJsonItem, SeqItem - - -class CMTLinear(CMTBaseAttr): - """ - A linear context manager template that handles the conversation flow between LLM and environment. - This class manages the context window, tokenization, and message history in a linear fashion. - - Attributes: - config: Configuration object containing environment and model settings - tokenizer: Tokenizer instance for processing text - full_context (List[ExtendedMessage]): List of all messages in the conversation - current_context_status (str): Current status of the context - max_seq_length (int): Maximum sequence length for the context window - max_env_output_length (int): Maximum length for environment outputs - terminal_rewards_dict (dict): Dictionary storing terminal rewards - """ - - """ - 1. prepare_next_llm_context - 2. check_context_token_num_safe - 3. prepare_world_interaction - 4. save_init_input - 5. save_llm_output - 6. save_env_output - 7. remove_last_context - 8. generate_log - 9. group_tokenize - """ - - - - def prepare_previous_context(self, mod='future'): - """ - Prepare the input context for future LLM call. - - Returns: - list: Array of message dictionaries containing role and content_for_future, - formatted for LLM input. - """ - if mod=='future': - message_arr = [ - {"role": c.role, "content": c.content_for_future} - for c in self.full_context - ] - return message_arr - - elif mod=='raw': - message_arr = [ - {"role": c.role, "content": c.content} - for c in self.full_context - ] - return message_arr - - else: - raise ValueError(f"Unknown mod {mod} in prepare_previous_context, only support 'future' and 'raw'") - - - def check_context_token_num_safe(self, messages: List[dict]) -> Tuple[bool, str]: - def get_seq_length(messages): - prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - return len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) - if self.already_mad_flag and self.config.astune.rollout.agent_madness_termination: - return False, "already_mad" - messages = self.prepare_previous_context(mod="raw") - if get_seq_length(messages) < self.max_seq_length: # self.config.env_engine.max_seq_length = 20480 - return True, "safe" - else: - return False, "token_overflow" - - - def get_inc(self, text_frag_from, text_frag_to): - """ - Get the incremental token array from text_frag_from to text_frag_to. - """ - tokenizer_output = self.tokenizer(text_frag_from, return_tensors="pt", padding=False) - tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() - token_ids_acc = tokenizer_input_ids - - tokenizer_output = self.tokenizer(text_frag_to, return_tensors="pt", padding=False) - input_ids = tokenizer_output["input_ids"][0].tolist() - input_id_increment = input_ids[len(token_ids_acc):] # get the new tokens added in this step - overlap_length = 0 - for i in range(len(token_ids_acc)): - if i < len(token_ids_acc) and input_ids[i] == token_ids_acc[i]: overlap_length += 1 - else: break - msg = f"previous token length: {len(token_ids_acc)}, overlap token length: {(overlap_length)}, increment token length: {len(input_id_increment)}" - # print(msg) - return input_id_increment, msg - - def remove_last_context(self): - if len(self.full_context) > 0: - if self.full_context[-1].author != "llm": - self.full_context.pop(-1) - - def remove_last_non_llm_msg(self, ext_msg_list:List[ExtendedMessage]): - if len(ext_msg_list) > 0: - if ext_msg_list[-1].author != "llm": - ext_msg_list.pop(-1) - return ext_msg_list - - - - @property - def steps(self): - return self.prepare_previous_context(mod='future') - - - def prepare_next_llm_context(self): - return self.prepare_previous_context(mod='future') - - - def save_init_input(self, init_input_arr:list, add_nothink: bool=False): - """ - Save and process the initial input messages to the context. - - Args: - init_input_arr (list): Array of initial input messages to be processed - Each message should be a dict with 'role' and 'content' - - Note: - - Initializes the context with the provided messages - - Computes token arrays for each message - - Validates that the context is empty before saving - """ - # save basic - assert len(self.full_context) == 0, "full_context should be empty when saving init input" - for index, llm_msg in enumerate(init_input_arr): - if (index == len(init_input_arr) - 1) and add_nothink: - llm_msg['content'] += "\n/no_think" - ext_msg = ExtendedMessage( - author="initialization", - role=llm_msg['role'], - content=llm_msg['content'], - token_generator="manual", - tokenizer=self.tokenizer, - ) - self.full_context += [ext_msg] - - # compute token array for each message - token_ids_acc = [] - for llm_msg, ext_msg, index in zip(init_input_arr, self.full_context, range(len(init_input_arr))): - text_with_chat_template = self.tokenizer.apply_chat_template(init_input_arr[:(index+1)], tokenize=False) - tokenizer_output = self.tokenizer(text_with_chat_template, return_tensors="pt", padding=False) - input_ids = tokenizer_output["input_ids"][0].tolist() - # attention_mask = outputs["attention_mask"][0].tolist() - input_id_increment = input_ids[len(token_ids_acc):] # get the new tokens added in this step - overlap_length = 0 - for i in range(len(token_ids_acc)): - if (i < len(token_ids_acc)) and (input_ids[i] == token_ids_acc[i]): overlap_length += 1 - else: break - ext_msg._info = f"previous token length: {len(token_ids_acc)}, overlap token length: {(overlap_length)}, increment token length: {len(input_id_increment)}" - ext_msg.token_arr = input_id_increment - token_ids_acc += input_ids - return - - def save_llm_output(self, llm_output, input_msg_ref, auto_register_full_context=True): - """ - Save the output from the LLM to the full context. - - Args: - llm_output (dict): The output from the LLM containing 'role', 'content', and 'tokens' - input_msg_ref: Reference to the input messages for token increment calculation - out_of_full_context: Register in full_context or not - - Note: - - Processes the LLM output and adds it to the conversation history - - Handles token processing and generation prompt management - - Ensures proper tokenization and context maintenance - """ - # save basic - assert isinstance(llm_output, dict) - token_generator = "manual" if 'tokens' in llm_output else "auto" - ext_msg = ExtendedMessage( - author="llm", - role=llm_output['role'], - content=llm_output['content'], - token_generator=token_generator, - tokenizer=self.tokenizer, - ) - if auto_register_full_context: - self.full_context += [ext_msg] - if not self.already_mad_flag: - if compute_string_madness(completion=llm_output['content'], checklist=self.config.astune.rollout.compute_madness_checklist) < 0.0: - self.already_mad_flag = True - - if token_generator == "manual": - token_arr_method2, token_logprob_arr = self.get_token_inc_from_vllm_response(input_msg_ref, llm_output) - ext_msg.token_arr = token_arr_method2 - ext_msg.token_logprob_arr = token_logprob_arr - - return ext_msg - - # generate token - def get_token_inc_from_vllm_response(self, input_msg_ref, llm_output) -> Tuple[List[int], List[int]]: - generation_prompt_token, msg = self.get_inc( - self.tokenizer.apply_chat_template(input_msg_ref, tokenize=False, add_generation_prompt=False), - self.tokenizer.apply_chat_template(input_msg_ref, tokenize=False, add_generation_prompt=True), - ) - # completion_token_arr will contain generation_prompt header - completion_token_arr, msg2 = self.get_inc( - # ... <|im_end|> - self.tokenizer.apply_chat_template(input_msg_ref, tokenize=False), - # ... <|im_end|><|im_start|>...<|im_end|> - self.tokenizer.apply_chat_template(input_msg_ref + [ {"role": llm_output['role'], "content": llm_output['content']} ], tokenize=False), - ) - vllm_output_raw_token = [t.token_id for t in llm_output['tokens']] - vllm_output_raw_logprob = [t.logprob for t in llm_output['tokens']] - self.generated_token_cnt += len(vllm_output_raw_token) - final_token_arr, token_logprob_arr = replace_token_ids( - place_holder=completion_token_arr, - replace_with=vllm_output_raw_token, - begin=generation_prompt_token, - end=[self.tokenizer.eos_token_id], - raw_logprob=vllm_output_raw_logprob, - ) - return final_token_arr, token_logprob_arr - - def save_llm_output_do_not_register_full_context(self, llm_output, input_msg_ref): - return CMTLinear.save_llm_output(self, llm_output, input_msg_ref, auto_register_full_context=False) - - - def save_env_output(self, env_output:dict, input_msg_ref:Optional[List[dict]]=None, add_nothink=False): - """ - Save and process environment output to the context. - - Args: - env_output (dict): Environment output containing 'content' - input_msg_ref (List[dict], optional): Reference messages for token calculation - - Note: - - Clips environment output if it exceeds max_env_output_length - - Processes the output as a user message in the conversation - - Computes and stores token arrays for the environment response - """ - assert isinstance(env_output, dict) - if ('content' not in env_output) and ('error' in env_output): - env_output['content'] = f"[Error from environment: {env_output['error']}]" - elif ('content' not in env_output) or (not env_output['content']): - env_output['content'] = 'Warning: the environment does not provide any feedback, please provide valid inpu and try again.' - if add_nothink: - env_output['content'] += " /no_think" - ext_msg = ExtendedMessage( - author="env", - role="user", - content=env_output['content'], - clip=True, - clip_token_limit=self.max_env_output_length, - token_generator="auto", - tokenizer=self.tokenizer, - ) - self.full_context += [ext_msg] - return - - def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List[dict]: - return [{"role": ext_msg.role, "content": ext_msg.content_for_future} for ext_msg in ext_msg_array] - - def prepare_world_interaction(self) -> str: - """ - Process the latest model content before environment interaction. - - Returns: - str: Processed content, with code extracted from markdown code blocks if present - or the raw content if no code blocks are found - - Note: - - Extracts Python code from markdown code blocks (```python```) - - Returns the raw content if no valid code blocks are found - """ - latest_content = self.full_context[-1].content - return latest_content - - def filter_context_via_author(self, author: str) -> List[ExtendedMessage]: - return copy.deepcopy([ c for c in self.full_context if c.author == author ]) - - def filter_context_via_authors(self, authors: List[str]) -> List[ExtendedMessage]: - return copy.deepcopy([ c for c in self.full_context if c.author in authors ]) - - def filter_context_via_authors_with_limit(self, authors: List[str], limit: dict) -> List[ExtendedMessage]: - """ - limit = { - "llm": "keep_last@2" - "env": "keep_first@2" - } - """ - filtered_via_authors = copy.deepcopy([ c for c in self.full_context if c.author in authors ]) - for limit_author, limit_item in limit.items(): - limit_item_command, limit_item_value = limit_item.split('@') - if limit_item_command == "keep_last": - limit_item_value = int(limit_item_value) - # remove all message whose author is `llm_author` except the last `limit_item_value` messages - num_need_rm = len([ c for c in filtered_via_authors if c.author == limit_author ]) - limit_item_value - if num_need_rm > 0: - num_already_rm = 0 - filtered_via_authors_new = [] - for c in filtered_via_authors: - if c.author == limit_author: - num_already_rm += 1 - if num_already_rm <= num_need_rm: - continue - filtered_via_authors_new += [c] - filtered_via_authors = filtered_via_authors_new - - elif limit_item_command == "keep_first": - limit_item_value = int(limit_item_value) - # remove all message whose author is `llm_author` except the first `limit_item_value` messages - num_need_keep = len([ c for c in filtered_via_authors if c.author == limit_author ]) - limit_item_value - if num_need_keep > 0: - num_already_keep = 0 - filtered_via_authors_new = [] - for c in filtered_via_authors: - if c.author == limit_author: - num_already_keep += 1 - if num_already_keep > limit_item_value: - continue - filtered_via_authors_new += [c] - filtered_via_authors = filtered_via_authors_new - - else: - raise ValueError(f"Unknown limit_item_command {limit_item_command} in filter_context_via_authors_with_limit") - return filtered_via_authors - - def group_tokenize(self): - sample_arr = [] - ext_steps = self.full_context - cmt_tokenized = self.tokenize_steps(ext_steps=ext_steps, index=0, total_steps=1) - sample = Sample( - cmt_tokenized = cmt_tokenized, - messages=self.to_role_content(ext_steps), - config=self.config, - - task_batch_index=self.task_batch_index, - task_tag=self.task_tag, - task_id=self.task_id, - ) - sample.truncate_output_ids() - sample_arr += [sample] - return sample_arr - - def group_tokenize_multi_group(self): - sample_arr = [] - max_num_group = self.config.astune.rollout.multi_turn.max_sample_per_task - for index, ext_steps in enumerate(self.grouped_steps): - cmt_tokenized = self.tokenize_steps(ext_steps=ext_steps, index=index, total_steps=len(self.grouped_steps)) - sample = Sample( - cmt_tokenized = cmt_tokenized, - messages=self.to_role_content(ext_steps), - config=self.config, - - task_batch_index=self.task_batch_index, - task_tag=self.task_tag, - task_id=self.task_id, - ) - sample_arr += [sample] - - if len(sample_arr) > max_num_group: - print(f"Warning: allow {max_num_group} groups, but got {len(sample_arr)} groups") - import random - sample_arr = random.sample(sample_arr, max_num_group) # 随机保留 max_num_group 个组 - - return sample_arr - - def generate_log(self, task_id=None, global_step="NA"): - task_id = self.task_id - nested_items_print_buffer = {} - ext_steps=self.full_context - cmt_tokenized = self.tokenize_steps(ext_steps=ext_steps, index=0, total_steps=1) - text_arr = [self.tokenizer.decode(t) for t in cmt_tokenized["input_ids"]] - input_id_arr = [str(t) for t in cmt_tokenized["input_ids"]] - loss_mask_color_arr = ["#09ABCF" if mask==1 else "#D98510" for mask in cmt_tokenized["loss_mask"]] - buffer = { - "text_arr": text_arr, - "input_id_arr": input_id_arr, - "loss_mask_color_arr": loss_mask_color_arr, - } - len_prompt_ids = len(cmt_tokenized["prompt_ids"]) - len_response_ids = len(cmt_tokenized["response_ids"]) - len_input_ids = len(cmt_tokenized["input_ids"]) - raw_reward = self.reward_structure.raw_reward - step_reward = self.reward_structure.step_reward[0] - try: - step_advantage = self.reward_structure.step_advantage[index] - step_advantage_simple = self.reward_structure.step_advantage_simple[index] - except: - step_advantage = 0.0 - step_advantage_simple = 0.0 - task_outcome = str(self.reward_structure.success_rate) - selectors = [task_id, task_outcome] - nested_items_print_buffer[f".".join(selectors)] = NestedJsonItem( - item_id=f"item", # type: ignore - outcome=task_outcome, # type: ignore - len_prompt_ids=len_prompt_ids, # type: ignore - len_response_ids=len_response_ids, # type: ignore - len_input_ids=len_input_ids, # type: ignore - raw_reward=f"{float(raw_reward):.3f}", # type: ignore - step_reward=f"{float(step_reward):.3f}", # type: ignore - step_advantage=f"{float(step_advantage):.3f}", # type: ignore - step_advantage_simple=f"{float(step_advantage_simple):.3f}", # type: ignore - content=SeqItem( - text = buffer['text_arr'], # 文本 - title = buffer['text_arr'], # 鼠标悬浮文本 - count = buffer['input_id_arr'], # 高亮文本 - color = buffer['loss_mask_color_arr'] # 颜色 - ) - ) - print_nested(nested_items_print_buffer, - main_content="This is the main content of the nested JSON", - header=f"[{global_step}] Task {task_id} (Reward {float(step_reward):.3f})", - mod="rollout", - narrow=False - ) - - - def process_reward(self, reward_structure: Reward): - self.reward_structure = reward_structure - ext_steps = self.full_context - # lienar 模式只有一条轨迹 - self.reward_structure.step_reward = [ - self.compute_step_level_reward(ext_steps=ext_steps, index=0, total_steps=1) - ] - - - def ensure_terminate_rollout_stage(self): - """Nothing need to be done for basic linear cmt at `ensure_terminate_rollout_stage` - """ - pass - - def compute_step_level_reward(self, ext_steps: List[ExtendedMessage], index: int, total_steps:int)->float: - assert self.reward_structure is not None - - # --------------- global level reward --------------- - global_reward = self.reward_structure.raw_reward - gamma = self.config.astune.rollout.gamma - step_reward_base = global_reward * (gamma ** (total_steps - index - 1)) - - # --------------- compute step level reward --------------- - step_reward = step_reward_base - if self.already_mad_flag: - step_reward = self.config.astune.rollout.agent_madness_reward - self.reward_structure.madness = -1.0 - - return step_reward - - - # def compute_step_level_reward(self, ext_steps: List[ExtendedMessage], index: int, total_steps:int)->float: - # assert self.reward_structure is not None - # # --------------- global level reward --------------- - # global_reward = self.reward_structure.raw_reward - # # here we assume global reward is given at the end of the trajectory - # gamma = self.config.astune.rollout.gamma - # step_reward_base = global_reward * (gamma ** (total_steps - index - 1)) - # # when index=0, total_steps=1, step_reward = global_reward * (gamma ** 0) = global_reward - # # when index=0, total_steps=2, step_reward = global_reward * (gamma ** 1) = global_reward * 0.95 - # # when index=0, total_steps=3, step_reward = global_reward * (gamma ** 2) = global_reward * 0.9025 - - # # --------------- compute step level reward --------------- - # step_reward = step_reward_base - # # # get all ext_step that need to be trained - # # trainable_ext_steps = [ ext_msg for ext_msg in ext_steps if ext_msg.need_training ] - - # # # --------------- compute step level reward: response madness --------------- - # # # in some cases, a step may contain multiple messages that need training, therefore we define mini_step - # # mini_step_reward = [] - # # for ext_msg in trainable_ext_steps: - # # assert ext_msg.need_training, "trainable_ext_steps should only contain messages that need training" - # # mini_step_reward += [compute_string_madness(completion=ext_msg.content_for_future)] - - # # if any([r < 0 for r in mini_step_reward]): - # # self.reward_structure.madness = -1.0 - # # step_reward = self.config.astune.rollout.agent_madness_reward - # # else: - # # pass - # if self.already_mad_flag: - # step_reward = self.config.astune.rollout.agent_madness_reward - # self.reward_structure.madness = -1.0 - - # return step_reward - - - def tokenize_steps(self, ext_steps: List[ExtendedMessage], index:int, total_steps:int) -> dict: - """ - Create an Experience object from the current conversation context. - - Returns: - Experience: An object containing processed conversation data for model training - - Note: - - Removes the last user message as it's not required in casual model training - - Processes input IDs, attention masks, and loss masks - - Separates prompt and response components - - Handles position IDs and reward scores - - Truncates output IDs as needed - """ - from verl.utils.model import compute_position_id_with_mask - ext_steps = self.remove_last_non_llm_msg(ext_steps) - - # check reward structure - self.reward_structure: Reward # type: ignore - assert self.reward_structure.step_reward is not None, "must call `process_reward` before tokenize_steps" - assert len(self.reward_structure.step_reward) == total_steps - - # mapping - input_ids = [] - input_logprobs = [] - attention_mask = [] - loss_mask = [] - split_prompt_reponse_index = -1 - split_point_message_left_index = -1 - input_ids_len = [] - - # cat all messages - for i, ext_msg in enumerate(ext_steps): - # find split index, this have to be done before input_ids += ext_msg.token_arr - if (split_prompt_reponse_index == -1) and (ext_msg.need_training): - split_prompt_reponse_index = len(input_ids) - split_point_message_left_index = i - 1 - assert split_point_message_left_index >= 0, "There should be at least one message before the first training message" - assert split_prompt_reponse_index == input_ids_len[split_point_message_left_index] - assert ext_msg.author == 'llm', "The first message after initialization should be from LLM, not from env or user" - - # cat all tokens - input_ids += ext_msg.token_arr - if len(ext_msg.token_logprob_arr) == 0: - input_logprobs += [ext_msg.invalid_log_prob_value] * len(ext_msg.token_arr) - else: - input_logprobs += ext_msg.token_logprob_arr - input_ids_len += [len(input_ids)] - attention_mask += [1] * len(ext_msg.token_arr) - loss_mask += ext_msg.get_loss_mask(blackout_token_combo=self.blackout_token_combo) - - if split_prompt_reponse_index == -1: - # should we begin split point early? - if input_ids_len[-1] > self.config.astune.data.max_prompt_length: - message_dict = self.to_role_content(ext_steps) - logger.error(f"Input ids exceeded max_prompt_length before encountering any training message! trying to fix...") - logger.bind(exception=True).exception(f"Input ids exceeded max_prompt_length before encountering any training message! trying to fix...\n\n" + str(message_dict)) - assert i >= 1, "There should be at least one message before exceeding max_prompt_length" - assert input_ids_len[-2] <= self.config.astune.data.max_prompt_length, "The previous message should be within max_prompt_length, something is wrong" - split_point_message_left_index = i - 1 - assert split_point_message_left_index == (len(input_ids_len) - 2), "what?" - split_prompt_reponse_index = input_ids_len[split_point_message_left_index] - - # check - assert len(ext_steps) == len(input_ids_len), "length of ext_steps and input_ids_len should be equal" - assert split_prompt_reponse_index != -1, "split_prompt_reponse_index should not be -1, at least one message should be in the context" - position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist() - - # sperate prompt and response - prompt_ids = input_ids[:split_prompt_reponse_index] - prompt_attention_mask = attention_mask[:split_prompt_reponse_index] - prompt_position_ids = position_ids[:split_prompt_reponse_index] - prompt_loss_mask = loss_mask[:split_prompt_reponse_index] - prompt_logprobs = input_logprobs[:split_prompt_reponse_index] - - response_ids = input_ids[split_prompt_reponse_index:] - response_attention_mask = attention_mask[split_prompt_reponse_index:] - response_position_ids = position_ids[split_prompt_reponse_index:] - response_loss_mask = loss_mask[split_prompt_reponse_index:] - response_logprobs = input_logprobs[split_prompt_reponse_index:] - - cmt_tokenized = {} - cmt_tokenized["input_ids"] = input_ids - cmt_tokenized["prompt_ids"] = prompt_ids - cmt_tokenized["response_ids"] = response_ids - cmt_tokenized["attention_mask"] = attention_mask - cmt_tokenized["logprobs"] = input_logprobs - cmt_tokenized["prompt_attention_mask"] = prompt_attention_mask - cmt_tokenized["response_attention_mask"] = response_attention_mask - cmt_tokenized["loss_mask"] = loss_mask - cmt_tokenized["prompt_loss_mask"] = prompt_loss_mask - cmt_tokenized["response_loss_mask"] = response_loss_mask - cmt_tokenized["position_ids"] = position_ids - cmt_tokenized["prompt_position_ids"] = prompt_position_ids - cmt_tokenized["response_position_ids"] = response_position_ids - cmt_tokenized["step_reward"] = self.reward_structure.step_reward[index] - cmt_tokenized["response_logprobs"] = response_logprobs - cmt_tokenized["prompt_logprobs"] = prompt_logprobs - try: - cmt_tokenized["reference_advantage"] = self.reward_structure.step_advantage[index] - except: - cmt_tokenized["reference_advantage"] = 0 - - return cmt_tokenized - - @staticmethod - def compute_reference_advantage(cmt_array: List): - import numpy as np - - task2cmt = defaultdict(list) - for cmt in cmt_array: - task2cmt[cmt.task_id] += [cmt] - - for task_id, cmt_list in task2cmt.items(): - cmt_reward = [] - - # 计算组内平均和标准差 - for cmt in cmt_list: - cmt_reward += [np.mean(cmt.reward_structure.step_reward)] - - if len(cmt_reward) == 1: - reward_mean = 0.0 - reward_std = 1.0 - else: - reward_mean = float(np.mean(cmt_reward)) - reward_std = float(np.std(cmt_reward, ddof=1)) - if reward_std < 0.01: - reward_std = 0.01 - - # logger.bind(exception=True).info(f"task id {task_id}") - # logger.bind(exception=True).info(f"reward_mean {reward_mean}, reward_std {reward_std}, cmt_reward {cmt_reward}") - # 计算 advantage - for cmt in cmt_list: - cmt.reward_structure.step_advantage = [] - for i in range(len(cmt.reward_structure.step_reward)): - cmt.reward_structure.step_advantage += [ - (cmt.reward_structure.step_reward[i] - reward_mean) / (reward_std + 1e-6) - ] - # logger.bind(exception=True).info(f"step reward {cmt.reward_structure.step_reward}") - # logger.bind(exception=True).info(f"step advantage {cmt.reward_structure.step_advantage}") - - # 计算简单advantage (不均衡rollout sample数量) - for task_id, cmt_list in task2cmt.items(): - cmt_reward = [] - for cmt in cmt_list: - cmt_reward.extend(cmt.reward_structure.step_reward) - if len(cmt_reward) == 1: - reward_mean = 0.0 - reward_std = 1.0 - else: - reward_mean = float(np.mean(cmt_reward)) - reward_std = float(np.std(cmt_reward, ddof=1)) - # if reward_std < 0.01: - # reward_std = 0.01 - for cmt in cmt_list: - cmt.reward_structure.step_advantage_simple = [] - for i in range(len(cmt.reward_structure.step_reward)): - cmt.reward_structure.step_advantage_simple += [ - (cmt.reward_structure.step_reward[i] - reward_mean) / (reward_std + 1e-6) - ] - - return - - diff --git a/astune/context_manager/cmt_linear_think.py b/astune/context_manager/cmt_linear_think.py deleted file mode 100644 index 8ea83592..00000000 --- a/astune/context_manager/cmt_linear_think.py +++ /dev/null @@ -1,311 +0,0 @@ -import copy -from textwrap import dedent -from typing import List, Tuple -from astune.schema.trajectory import Sample, Reward -from astune.context_manager.cmt_linear import ExtendedMessage, CMTLinear -from beast_logger import register_logger, print_dict, print_nested, NestedJsonItem, SeqItem - -class MultiSampleCMT(CMTLinear): - def __init__(self, config, tokenizer): - super().__init__(config, tokenizer) - self.config = config - self.tokenizer = tokenizer - self.full_context: List[ExtendedMessage] = [] - self.current_context_status = "" - max_response_length = self.config.astune.rollout.max_response_length_in_one_turn - max_model_len: int = self.config.astune.rollout.max_model_len - - assert self.config.astune.data.max_response_length < self.config.astune.data.max_prompt_length, "think linear template requires a big max_prompt_length" - - self.max_seq_length: int = max_model_len - max_response_length - assert self.max_seq_length <= self.config.astune.data.max_prompt_length, "max_seq_length should be less than or equal to max_prompt_length" - - - self.max_env_output_length: int = self.config.astune.rollout.max_env_len - self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") - - self.terminal_rewards_dict = {} - self.latest_llm_interaction_socket: List[ExtendedMessage] = None - self.grouped_steps: List[List[ExtendedMessage]] = [] - - self.discarded = False - self.is_terminated = False - self.context_time_cost = 0 - self.already_mad_flag = False - - self.force_think = config.astune.rollout.force_think - self.env_action_preference = config.astune.task_reader.env_service.env_action_preference - if not self.force_think: - # think_hint_for_qwen3 = - self.think_hint: str = "\n\nThink about the next step before answering. Your thought (...) should be as short and concise as possible." - else: - if self.env_action_preference == "box": - force_think_prompt = dedent(""" - Additional requirements: Think before action! You must think step by step before your next action, and you must use ... to wrap your thinking process before finally produce your answer with \\box{}. - For example: - ...your thinking process... - \\box{...your final answer...} - """) - elif self.env_action_preference == "code": - force_think_prompt = dedent(""" - Additional requirements: Think before action! You must think step by step before your next action, and you must use ... to wrap your thinking process before finally produce the next-step action. - For example: - ...your thinking process... - ```python - # your action here - ``` - """) - else: - raise ValueError(f"Unsupported env_action_preference: {self.env_action_preference}") - # think_hint_for_qwen2 = - self.think_hint: str = force_think_prompt - - def _get_seq_length(self, messages: List[dict]) -> int: - prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - return len(self.tokenizer(prompt_text, return_tensors="pt", padding=False)["input_ids"][0]) - - def check_context_token_num_safe(self, messages: List[dict]) -> Tuple[bool, str]: - if self.already_mad_flag and self.config.astune.rollout.agent_madness_termination: - return False, "already_mad" - if self._get_seq_length(messages) < self.max_seq_length: # self.config.env_engine.max_seq_length = 20480 - return True, "safe" - else: - return False, "token_overflow" - - @property - def steps(self): - # TODO: need revise - return self.prepare_previous_context(mod='future') - - def generate_log(self, task_id = None, global_step="NA"): - task_id = self.task_id - nested_items_print_buffer = {} - for index, ext_steps in enumerate(self.grouped_steps): - cmt_tokenized = self.tokenize_steps(ext_steps=ext_steps, index=index, total_steps=len(self.grouped_steps)) - text_arr = [self.tokenizer.decode(t) for t in cmt_tokenized["input_ids"]] - input_id_arr = [str(t) for t in cmt_tokenized["input_ids"]] - loss_mask_color_arr = ["#09ABCF" if mask==1 else "#D98510" for mask in cmt_tokenized["loss_mask"]] - buffer = { - "text_arr": text_arr, - "input_id_arr": input_id_arr, - "loss_mask_color_arr": loss_mask_color_arr, - } - raw_reward = self.reward_structure.raw_reward - step_reward:float = self.reward_structure.step_reward[index] - try: - step_advantage = self.reward_structure.step_advantage[index] - step_advantage_simple = self.reward_structure.step_advantage_simple[index] - except: - step_advantage = 0.0 - step_advantage_simple = 0.0 - task_outcome = str(self.reward_structure.success_rate) - selectors = [task_id, task_outcome, str(index)] - len_prompt_ids = len(cmt_tokenized["prompt_ids"]) - len_response_ids = len(cmt_tokenized["response_ids"]) - len_input_ids = len(cmt_tokenized["input_ids"]) - assert len_prompt_ids + len_response_ids == len_input_ids, "len_prompt_ids + len_response_ids should equal to len_input_ids" - # print(f"Task {task_id}, outcome {task_outcome}, group {index}, len_prompt_ids {len_prompt_ids}, len_response_ids {len_response_ids}, len_input_ids {len_input_ids}") - nested_items_print_buffer[f".".join(selectors)] = NestedJsonItem( - item_id=f"item", - outcome=task_outcome, - len_prompt_ids=len_prompt_ids, - len_response_ids=len_response_ids, - len_input_ids=len_input_ids, - raw_reward=f"{float(raw_reward):.3f}", - step_reward=f"{float(step_reward):.3f}", - step_advantage=f"{float(step_advantage):.3f}", - step_advantage_simple=f"{float(step_advantage_simple):.3f}", - content=SeqItem( - text = buffer['text_arr'], # 文本 - title = buffer['text_arr'], # 鼠标悬浮文本 - count = buffer['input_id_arr'], # 高亮文本 - color = buffer['loss_mask_color_arr'] # 颜色 - ) - ) - print_nested(nested_items_print_buffer, - main_content="This is the main content of the nested JSON", - header=f"[{global_step}] Task {task_id} (Reward {float(step_reward):.3f})", - mod="rollout", - narrow=False, - attach="copy this" - ) - - - def group_tokenize(self): - return self.group_tokenize_multi_group() - - - def process_reward(self, reward_structure: Reward): - # lienar 模式有多条轨迹 - use_step_reward_from_env = self.config.astune.rollout.get("use_step_reward_from_env", False) - if not use_step_reward_from_env: - self.reward_structure = reward_structure - self.reward_structure.step_reward = [0.0 for _ in range(len(self.grouped_steps))] - for index, ext_steps in enumerate(self.grouped_steps): - self.reward_structure.step_reward[index] = self.compute_step_level_reward( - ext_steps=ext_steps, - index=index, - total_steps=len(self.grouped_steps) - ) - else: - step_reward = reward_structure.raw_step_reward - assert reward_structure.raw_step_reward - assert len(reward_structure.raw_step_reward) == len(self.grouped_steps), f"len(reward_structure.raw_step_reward) {len(reward_structure.raw_step_reward)} should equal to len(self.grouped_steps) {len(self.grouped_steps)}" - self.reward_structure = reward_structure - self.reward_structure.step_reward = reward_structure.raw_step_reward - - def compute_step_level_reward(self, ext_steps: List[ExtendedMessage], index: int, total_steps:int)->float: - assert self.reward_structure is not None - - # --------------- global level reward --------------- - global_reward = self.reward_structure.raw_reward - gamma = self.config.astune.rollout.gamma - step_reward_base = global_reward * (gamma ** (total_steps - index - 1)) - - # --------------- compute step level reward --------------- - step_reward = step_reward_base - if self.already_mad_flag: - step_reward = self.config.astune.rollout.agent_madness_reward - self.reward_structure.madness = -1.0 - - return step_reward - - - -class LinearThinkCMT(MultiSampleCMT): - """ - A linear context manager template that handles the conversation flow between LLM and environment. - This class manages the context window, tokenization, and message history in a linear fashion. - - Attributes: - config: Configuration object containing environment and model settings - tokenizer: Tokenizer instance for processing text - full_context (List[ExtendedMessage]): List of all messages in the conversation - current_context_status (str): Current status of the context - max_seq_length (int): Maximum sequence length for the context window - max_env_output_length (int): Maximum length for environment outputs - terminal_rewards_dict (dict): Dictionary storing terminal rewards - - """ - - - - - - def prepare_next_llm_context(self): - self.latest_llm_interaction_socket = [] - # 筛选出 `初始message-user-llm-user-llm`` 或者 `初始message-llm-user-llm-user`` - self.latest_llm_interaction_socket = self.filter_context_via_authors(["initialization", "llm", "env"]) - - for index, ext_msg in enumerate(list(self.latest_llm_interaction_socket)): - # is_last 是最后一条信息 - # remove history llm author's think (and add /no_think tag to every but last message) - is_last = (index == len(self.latest_llm_interaction_socket) - 1) - # 根据消息类型进行处理 - if ext_msg.author == "llm": - # 如果是以往的llm消息,去掉think标签 - import re - new_ext_msg_content = re.sub(r'.*?', '', ext_msg.content, flags=re.DOTALL).strip() - new_ext_msg_content = new_ext_msg_content.replace("", "") - new_ext_msg_content = new_ext_msg_content.replace("", "") - # new_ext_msg_content = re.sub(r'.*?', '\n\n', ext_msg.content, flags=re.DOTALL) - - if self.config.astune.context_manager.linear_think_cm.train_history_infer_token: - assert ext_msg.author == "llm" - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=new_ext_msg_content, - token_generator='auto', - tokenizer=self.tokenizer, - ) - else: - assert ext_msg.author == "llm" - author_override = "llm(do_not_train)" - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=author_override, - role=ext_msg.role, - content=new_ext_msg_content, - token_generator='auto', - tokenizer=self.tokenizer, - ) - elif ext_msg.author in ["env", "initialization"]: - if self.config.astune.context_manager.linear_think_cm.train_history_infer_token: - # 如果是初始化或者环境反馈,都加上 /no_think 标签 - if not is_last: - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=ext_msg.content_for_future + "\n/no_think", - token_generator='auto', - tokenizer=self.tokenizer, - ) - else: - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=ext_msg.content_for_future + self.think_hint, - token_generator='auto', - tokenizer=self.tokenizer, - ) - else: - # 如果是初始化或者环境反馈 - if not is_last: - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=ext_msg.content_for_future, - token_generator='auto', - tokenizer=self.tokenizer, - ) - else: - self.latest_llm_interaction_socket[index] = ExtendedMessage( - author=ext_msg.author, - role=ext_msg.role, - content=ext_msg.content_for_future + self.think_hint, - token_generator='auto', - tokenizer=self.tokenizer, - ) - else: - raise RuntimeError(f"Unknown author {ext_msg.author} in latest_llm_interaction_socket") - - dict_context = self.to_role_content(self.latest_llm_interaction_socket) - return dict_context - - - - def save_llm_output(self, llm_output, input_msg_ref): - ext_msg = super().save_llm_output(llm_output, input_msg_ref) - this_interaction = copy.deepcopy(self.latest_llm_interaction_socket + [ext_msg]) - self.grouped_steps += [this_interaction] - self.latest_llm_interaction_socket = [] - return ext_msg - - - def save_env_output(self, env_output:dict, input_msg_ref:List[dict]=None, add_nothink=False): - super().save_env_output(env_output, input_msg_ref, add_nothink) - return - - - def prepare_world_interaction(self) -> str: - latest_content = self.full_context[-1].content - if self.config.astune.context_manager.linear_think_cm.remove_think_before_submit_as_action: - import re - new_ext_msg_content = re.sub(r'.*?', '', latest_content, flags=re.DOTALL).strip() - new_ext_msg_content = new_ext_msg_content.replace("", "") - new_ext_msg_content = new_ext_msg_content.replace("", "") - latest_content = new_ext_msg_content.strip() - if self.config.astune.context_manager.linear_think_cm.extract_box_before_submit_as_action: - # take content within \box - # 提取 \box 中的内容 - import re - box_pattern = r'\\box\{(.*?)\}' - match = re.search(box_pattern, latest_content, re.DOTALL) - if match: - latest_content = match.group(1).strip() - else: - # 如果没有找到 \box,选择保留原内容 - pass - return latest_content - - diff --git a/astune/context_manager/cmt_phantom_hint.py b/astune/context_manager/cmt_phantom_hint.py deleted file mode 100644 index c10f7a14..00000000 --- a/astune/context_manager/cmt_phantom_hint.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import List -from astune.context_manager.cmt_linear import ExtendedMessage, CMTLinear - - -class LinearThinkCMT(CMTLinear): - """ - A linear context manager template that handles the conversation flow between LLM and environment. - This class manages the context window, tokenization, and message history in a linear fashion. - - Attributes: - config: Configuration object containing environment and model settings - tokenizer: Tokenizer instance for processing text - full_context (List[ExtendedMessage]): List of all messages in the conversation - current_context_status (str): Current status of the context - max_seq_length (int): Maximum sequence length for the context window - max_env_output_length (int): Maximum length for environment outputs - terminal_rewards_dict (dict): Dictionary storing terminal rewards - - 1. prepare_next_llm_context - 2. check_context_token_num_safe - 3. prepare_world_interaction - 4. save_init_input - 5. save_llm_output - 6. save_env_output - 7. remove_last_context - 8. generate_log - 9. group_tokenize - """ - - def __init__(self, config, tokenizer, contain_phantom_hint=False, past_trajectory=None): - super().__init__(config, tokenizer) - self.contain_phantom_hint = contain_phantom_hint - self.past_trajectory = past_trajectory - self.helper_llm_handle = None - - def save_init_input(self, init_input_arr:list, add_nothink: bool=False): - if self.contain_phantom_hint: - ... - - return super().save_init_input(init_input_arr, add_nothink) - diff --git a/astune/context_manager/cmt_sliding_window.py b/astune/context_manager/cmt_sliding_window.py deleted file mode 100644 index b24fc813..00000000 --- a/astune/context_manager/cmt_sliding_window.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import List, Callable, Tuple -from beast_logger import print_listofdict -from astune.context_manager.cmt_linear_think import ExtendedMessage, MultiSampleCMT -from loguru import logger - -""" -滑窗context管理器 -- 当context超出最大长度时,开始新的滑窗 -- 新的滑窗保留初始化信息、最新的env和llm信息 -- 其他信息忽略,生成一条“[Previous {x} conversation has been omitted for brevity.]”信息 -""" - - -class SlidingWindowCMT(MultiSampleCMT): - """ - A non-linear context manager template that handles the conversation flow between LLM and environment. - """ - - def __init__(self, config, tokenizer, llm_chat_fn): - self.llm_chat_fn = llm_chat_fn - self.latest_env_response_id = "" - self.latest_env_response_content = "" - self.console_debug_mode = False - self.force_think = config.astune.rollout.force_think - self.env_cnt = 0 - self.llm_cnt = 0 - self.config = config - self.tokenizer = tokenizer - self.full_context: List[ExtendedMessage] = [] - self.current_context_status = "" - max_response_length = self.config.astune.rollout.max_response_length_in_one_turn - max_model_len: int = self.config.astune.rollout.max_model_len - self.max_seq_length: int = max_model_len - max_response_length - self.max_env_output_length: int = self.config.astune.rollout.max_env_len - self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n") - self.terminal_rewards_dict = {} - self.latest_llm_interaction_socket: List[ExtendedMessage] = None - self.grouped_steps: List[List[ExtendedMessage]] = [] - self.discarded = False - self.is_terminated = False - self.context_time_cost = 0 - self.generated_token_cnt = 0 - self.omitted_msg_so_far = 0 - self.prompt_part_token_overflow = False - self.already_mad_flag = False - self.round_cnt = 0 - - def prepare_next_llm_context(self): - """Prepare the next LLM context with sliding window logic. When the token length exceeds the maximum limit, start a new sliding window. - """ - self.latest_llm_interaction_socket = self.filter_context_via_authors(["initialization", "llm", "llm(do_not_train)", "env", "memory"]) - dict_context = self.to_role_content(self.latest_llm_interaction_socket) - - # if token overflow, begin new sliding window - cur_seq_len = self._get_seq_length(dict_context) - # print(f"cur_seq_len {cur_seq_len}, self.max_seq_length {self.max_seq_length}") - - is_safe: bool = cur_seq_len < self.max_seq_length - if not is_safe: - _, previous_interaction = self._prepare_next_llm_context_static() - self.begin_new_sliding_window(previous_interaction=previous_interaction) - - dict_context, self.latest_llm_interaction_socket = self._prepare_next_llm_context_static() - cur_seq_len = self._get_seq_length(dict_context) - if cur_seq_len > self.config.astune.data.max_prompt_length: - print(f"Warning! cur_seq_len={cur_seq_len} immediately after new sliding window is created") - print_listofdict( - dict_context, mod='env_clip' - ) - self.prompt_part_token_overflow = True - - return dict_context - - - def _prepare_next_llm_context_static(self): - """Fetch from context and convert to dict format. - """ - latest_llm_interaction_socket = self.filter_context_via_authors(["initialization", "llm", "llm(do_not_train)", "env", "memory"]) - dict_context = self.to_role_content(latest_llm_interaction_socket) - return dict_context, latest_llm_interaction_socket - - - def check_context_token_num_safe(self, messages: List[dict]) -> Tuple[bool, str]: - """Always be safe because we already check in `prepare_next_llm_context` - """ - if self.already_mad_flag and self.config.astune.rollout.agent_madness_termination: - return False, "already_mad" - - assert self._get_seq_length(messages) < self.max_seq_length - - if self.prompt_part_token_overflow: - return False, "prompt_part_token_overflow" - else: - return True, "safe" - - - def begin_new_sliding_window(self, previous_interaction): - """Begin a new sliding window by preserving initialization, latest env and llm messages, and summarizing the rest into a memory message. - """ - self.grouped_steps += [previous_interaction] - recall_x_action = 2 - # delete most `llm` and `env` messages, keep only the last 2 of each - preserve_messages = self.filter_context_via_authors_with_limit( - authors = ["initialization", "llm", "env", "memory"], - limit={ - "llm": f"keep_last@{recall_x_action}", - "env": f"keep_last@{recall_x_action+1}", - "memory": "keep_last@1", - } - ) - other_messages = [ext_msg for ext_msg in self.filter_context_via_authors(authors = ["initialization", "llm", "env", "memory"]) if ext_msg not in preserve_messages] - # TODO: find a way to summarize previous messages - self.omitted_msg_so_far += len(other_messages) - # init message in `preserve_messages` - init_message_in_preserve_messages = [msg for msg in preserve_messages if msg.author == "initialization"] - # create memory message - other_messages = init_message_in_preserve_messages + other_messages # include init when create memory - memory_msg = self.create_memory_message(other_messages) - # inseart `preserve_messages` after initialization - new_context_beginning = init_message_in_preserve_messages + [memory_msg] + [msg for msg in preserve_messages if msg.author != "initialization"] - # disable llm training for all message in `new_context_beginning` - for i in range(len(new_context_beginning)): - ext_msg = new_context_beginning[i] - if ext_msg.author == 'llm': - author_override = "llm(do_not_train)" - new_context_beginning[i] = ExtendedMessage( - author=author_override, - role=ext_msg.role, - content=ext_msg.content_for_future, - token_generator='auto', - tokenizer=self.tokenizer, - ) - self.full_context = new_context_beginning - # delete old memory message - self.full_context = self.filter_context_via_authors_with_limit( - authors = ["initialization", "llm", "llm(do_not_train)", "env", "memory"], - limit={ - "memory": "keep_last@1", - } - ) - - def create_memory_message(self, msg_list: List[ExtendedMessage]) -> ExtendedMessage: - """TODO: create a better summary message - """ - x = self.omitted_msg_so_far // 2 - enable_llm_memory_extraction = self.config.astune.context_manager.sliding_window_cm - if not enable_llm_memory_extraction: - return ExtendedMessage( - author="memory", - role="user", - content=f"[Previous {x} round of conversations have been omitted for brevity.]", - token_generator='auto', - tokenizer=self.tokenizer, - ) - else: - return ExtendedMessage( - author="memory", - role="user", - content=self.llm_memory_extraction(msg_list), - token_generator='auto', - tokenizer=self.tokenizer, - ) - - def ensure_terminate_rollout_stage(self): - previous_interaction_dict_context, previous_interaction = self._prepare_next_llm_context_static() - if any([ext_msg.need_training for ext_msg in previous_interaction]): - self.grouped_steps += [previous_interaction] - - - def save_env_output(self, env_output, input_msg_ref = None, add_nothink=False): - self.env_cnt += 1 - env_output['content'] = f"[Current Env Step {self.env_cnt}]\n\n" + env_output['content'] - return super().save_env_output(env_output, input_msg_ref, add_nothink) - - - def save_llm_output(self, llm_output, input_msg_ref, auto_register_full_context=True): - self.llm_cnt += 1 - return super().save_llm_output(llm_output, input_msg_ref) - - - def llm_memory_extraction(self, msg_list: List[ExtendedMessage]) -> str: - """Use LLM to extract memory from previous messages. - """ - from astune.context_manager.cmt_foreign_llm import construct_alien_llm_chat_fn - from textwrap import dedent - self.alien_llm_chat_fn: Callable = construct_alien_llm_chat_fn(self.config, self.config.actor_rollout_ref.rollout) - messages = self.to_role_content(msg_list) - messages.append({ - "role": "user", - "content": dedent(""" - New task: Summarize the previous attempts into a concise memory statement that captures the key points and context. - - Start with: Previously, X attempts have been made, in these attempts, ... - - Focus on the main events, actions, and outcomes. - - If there are big or repeated failures, try to find reason and provide some future advice. - """) - }) - - try: - llm_output = self.alien_llm_chat_fn(messages, request_id="") - except Exception as e: - logger.bind(exception=True).exception(f"call alien_llm_chat_fn error with {e}") - x = self.omitted_msg_so_far // 2 - llm_output_content = f"[Previous {x} round of conversations have been omitted for brevity.]" - return llm_output_content - return llm_output['content'] diff --git a/astune/default_config/config_auto_convertion_trinity.json b/astune/default_config/config_auto_convertion_trinity.json deleted file mode 100644 index 231accd6..00000000 --- a/astune/default_config/config_auto_convertion_trinity.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "astune.model.path": "model.model_path", - "astune.project_name": "project", - "astune.experiment_name": "name" -} \ No newline at end of file diff --git a/astune/default_config/config_auto_convertion_verl.json b/astune/default_config/config_auto_convertion_verl.json deleted file mode 100644 index 7a2db5c6..00000000 --- a/astune/default_config/config_auto_convertion_verl.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "astune.model.path": "actor_rollout_ref.model.path", - "astune.project_name": "trainer.project_name", - "astune.experiment_name": "trainer.experiment_name" -} \ No newline at end of file diff --git a/astune/default_config/default.yaml b/astune/default_config/default.yaml deleted file mode 100644 index ae62dad9..00000000 --- a/astune/default_config/default.yaml +++ /dev/null @@ -1,151 +0,0 @@ -astune: - model: - path: /mnt/data/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - data: - max_prompt_length: 3000 - max_response_length: 15000 - train_batch_size: 32 - rollout: - use_agentscope_protocol: True - agentscope_learn_protocol: null - use_step_reward_from_env: False - binary_reward: False - force_no_think: False - force_think: False - compute_madness_checklist: - - "nonsense" - gamma: 1.0 - agent_madness_termination: True # terminate_after_gone_mad - agent_madness_reward: -1.0 # customize the reward when agent is detected as gone mad - add_special_success_reward: False - temperature: 0.9 - top_p: 1.0 - max_env_len: 4096 - max_response_length_in_one_turn: 4096 - max_model_len: 18000 - multi_turn: - max_sample_per_task: 30 - max_steps: 30 - step_skip_action: 0 # skip action generation every N steps, 0 means never skip - submit_oversample_multiplier: 1.5 - enable_oversample: True - num_repeat: 4 - val_kwargs: - temperature: 0.0 - top_k: -1 - top_p: 1.0 - context_manager: # context manager protocol is used ONLY when `use_agentscope_protocol=False` - context_manager_type: "linear" - alien_llm_model: qwen3-235b-a22b-instruct-2507 - alien_llm_response_length: 512 - auto_context_cm: - train_sp_action: False - token_num_trigger_clip: 8000 - sliding_window_cm: - enable_llm_memory_extraction: False - linear_think_cm: - remove_think_before_submit_as_action: False - extract_box_before_submit_as_action: False - train_history_infer_token: True - - - -########################## verl config below ########################## -trainer: - val_before_train: False - hfmodelpath: "" - experiment_name: "read_yaml_name" - n_gpus_per_node: 8 - nnodes: 1 - save_freq: 20 - test_freq: 20 - total_epochs: 50 - project_name: appworldnew - validation_data_dir: "experiments/exp_default/validation_log" - rollout_data_dir: "experiments/exp_default/rollout_log" - critic_warmup: 0 - eval_pass_n: 4 - logger: - - console - - swanlab - - -data: - val_batch_size: 100000000000 - return_raw_chat: True - filter_overlong_prompts: True - truncation: error - fast_eval: True - train_batch_size: 32 - max_prompt_length: 3000 - max_response_length: 15000 - -algorithm: - task_norm_patch: False - adv_estimator: grpo - use_kl_in_reward: False - -actor_rollout_ref: - hybrid_engine: True - actor: - entropy_coeff: 0 - loss_agg_mode: seq-mean-token-mean - override_ppo_mini_batch_num: 1 - ppo_epochs: 1 - ppo_mini_batch_size: 16 - optim: - lr: 1e-6 - use_kl_loss: True - kl_loss_coef: 0.002 - kl_loss_type: low_var_kl - ppo_micro_batch_size_per_gpu: 1 - ppo_max_token_len_per_gpu: 18000 - use_dynamic_bsz: True - fsdp_config: - param_offload: True - optimizer_offload: True - - rollout: - name: vllm - mode: async - max_env_len: 3000 - response_length: 3000 - prompt_length: 15000 - max_model_len: 18000 - use_agentscope_protocol: False - ppo_micro_batch_size_per_gpu: 1 - tensor_model_parallel_size: 1 - max_num_seqs: 10 - gpu_memory_utilization: 0.9 - max_env_worker: 64 - log_prob_max_token_len_per_gpu: 18000 - temperature: 0.9 - top_p: 1.0 - gamma: 1.0 - enforce_eager: True - log_prob_micro_batch_size_per_gpu: 4 - multi_turn: - completion_callback: beyondagent.module.trainer.simple_completion_callback.SimpleCompletionCallback - enable: True - format: llama3_json - max_steps: 30 - tool_config_path: null - custom_dataflow_cls: - path: "" - name: "" - val_kwargs: - top_k: -1 - top_p: 1.0 - temperature: 0 - do_sample: False - ref: - use_dynamic_bsz: True - log_prob_micro_batch_size_per_gpu: 4 - log_prob_max_token_len_per_gpu: 18000 - fsdp_config: - param_offload: True - - model: - use_remove_padding: True - enable_gradient_checkpointing: True - diff --git a/astune/default_config/trinity_default.yaml b/astune/default_config/trinity_default.yaml deleted file mode 100644 index ab524b8e..00000000 --- a/astune/default_config/trinity_default.yaml +++ /dev/null @@ -1,68 +0,0 @@ -algorithm: - algorithm_type: multi_step_grpo - optimizer: - lr: 1e-6 - repeat_times: 6 -buffer: - batch_size: 8 - explorer_input: - eval_tasksets: [] - taskset: - default_workflow_type: astune_workflow - format: - prompt_key: question - response_key: answer - name: gsm8k - path: http://localhost:8080 - rollout_args: - temperature: 1.0 - split: train - storage_type: astune - subset_name: appworld - total_epochs: 1000 - train_batch_size: 36 - trainer_input: - experience_buffer: - max_read_timeout: 18000 - name: agentscope_gsm8k_buffer - storage_type: queue -checkpoint_root_dir: ./trinity_checkpoints -cluster: - gpu_per_node: 8 - node_num: 1 -explorer: - eval_interval: 999999 - max_repeat_times_per_runner: 1 - max_timeout: 7200 - rollout_model: - dtype: bfloat16 - enable_auto_tool_choice: true - enable_history: true - enable_openai_api: true - enable_prefix_caching: false - enable_thinking: false - enforce_eager: true - engine_num: 2 - seed: 42 - tensor_parallel_size: 1 - tool_call_parser: hermes - runner_per_model: 12 -model: - max_model_len: 21000 - max_response_tokens: 16000 - model_path: /mnt/data/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct -monitor: - monitor_type: swanlab -name: git-math-agentscope -project: appworld_astune -synchronizer: - sync_interval: 2 - sync_method: nccl - sync_style: dynamic_by_explorer - sync_timeout: 1200 -trainer: - grad_clip: 1.0 - max_token_len_per_gpu: 24576 - save_interval: 100 - ulysses_sequence_parallel_size: 2 - use_dynamic_bsz: true diff --git a/astune/env.py b/astune/env.py deleted file mode 100644 index 68fb07ae..00000000 --- a/astune/env.py +++ /dev/null @@ -1,114 +0,0 @@ -import uuid -import random - -from omegaconf import DictConfig -from astune.env_service_client.env_client_ng import EnvClient as EnvClientNg -from astune.agent_flow import AgentFlow -from astune.context_manager.cmt_linear import CMTLinear -from loguru import logger -from typing import List, Union -from astune.agentscope_flow import AgentScopeWorkflow -from astune.agent_flow import AgentFlow -from astune.agent_flow import BaseAgentFlow - -class EnvWorker(object): - - def __init__(self, task_core_arg, config: DictConfig): - self.config = config - - if config.astune.task_reader.type == 'env_service': - url = config.astune.task_reader.env_service.env_url - env_type = config.astune.task_reader.env_service.env_type - self.env = EnvClientNg(base_url=url) - self.env_params = {} - self.env_type: str = env_type - else: - self.env = None - - self.task_core_arg = task_core_arg - self.task_id: str = task_core_arg.task_id - self.tokenizer = task_core_arg.tokenizer - self.llm_chat_fn = task_core_arg.llm_chat_fn - self.obs_window = task_core_arg.obs_window - - def execute(self) -> CMTLinear: - - # >>>>>>>>>>>>>> create - if self.config.astune.task_reader.type == 'env_service': - try: - init_response = self.env.create_instance( - env_type=self.env_type, - task_id=self.task_id, - instance_id=self.task_core_arg.task_env_uuid, - params=self.env_params - ) - state_message: dict = init_response["state"] - _, init_messages = self.get_init_messages(state_message) - except Exception as e: - logger.bind(exception=True).exception(f"encounter exception in env_worker.create_instance~ error={e.args}") - self.env.release_instance(self.task_core_arg.task_env_uuid) - raise e - else: - task = self.task_core_arg.task - if task.init_messages: - init_messages = task.init_messages - else: - assert task.main_query, "You must provide init_messages or main_query in task." - init_messages = [{"role": "user", "content": task.main_query}] - - # =============== simulate - try: - - if not self.config.astune.rollout.use_agentscope_protocol: - agent_flow: BaseAgentFlow = AgentFlow(llm_chat_fn=self.llm_chat_fn, tokenizer=self.tokenizer, config=self.config) - else: - agent_flow: BaseAgentFlow = AgentScopeWorkflow(llm_chat_fn=self.llm_chat_fn, tokenizer=self.tokenizer, config=self.config) - - cmt = agent_flow.execute( - init_messages=init_messages, - env=self.env, # type:ignore || self.env: Union[EnvClient, EnvClientNg] - task_core_arg=self.task_core_arg - ) - cmt.task_batch_index = self.task_core_arg.task_batch_index - cmt.task_tag = self.task_core_arg.task_tag - cmt.task_id = self.task_id - - except Exception as e: - logger.bind(exception=True).exception(f"encounter exception in env_worker.agent_flow~ error={e.args}") - if self.env: self.env.release_instance(self.task_core_arg.task_env_uuid) - raise e - - # <<<<<<<<<<<<<< destory - try: - if self.env: self.env.release_instance(self.task_core_arg.task_env_uuid) - except Exception as e: - logger.bind(exception=True).exception(f"encounter exception in env_worker.release_instance~ error={e.args}") - raise e - - return cmt - - - def get_init_messages(self, state_message) -> tuple: - """ - Process state_message to extract query and init_messages. - - Args: - state_message (Union[dict, list]): The state message to process - - Returns: - tuple: (query, init_messages) where query is a string and init_messages is a list - - Raises: - ValueError: If state_message is neither dict nor list - """ - if isinstance(state_message, dict): - query = state_message["content"] - init_messages = [state_message] - elif isinstance(state_message, list): - assert isinstance(state_message[0], dict) - query = state_message[-1]["content"] - init_messages = state_message - else: - raise ValueError(f"state_message should be dict or list, but got {type(state_message)}") - - return query, init_messages diff --git a/astune/env_service_client/em_client.py b/astune/env_service_client/em_client.py deleted file mode 100644 index a6898449..00000000 --- a/astune/env_service_client/em_client.py +++ /dev/null @@ -1,90 +0,0 @@ -import asyncio -import time -from typing import List - -from loguru import logger -from pydantic import Field - -from astune.schema.trajectory import Trajectory, Reward -from astune.utils.http_client import HttpClient - - -class EMClient(HttpClient): - base_url: str = Field(default="http://localhost:8001") - timeout: int = Field(default=1200 , description="request timeout, second") - - def call_context_generator(self, trajectory: Trajectory, retrieve_top_k: int = 1, workspace_id: str = "default", - **kwargs) -> str: - start_time = time.time() - self.url = self.base_url + "/context_generator" - json_data = { - "trajectory": trajectory.model_dump(), - "retrieve_top_k": retrieve_top_k, - "workspace_id": workspace_id, - "metadata": kwargs - } - response = self.request(json_data=json_data, headers={"Content-Type": "application/json"}) - if response is None: - logger.warning("error call_context_generator") - return "" - - # TODO return raw experience instead of context @jinli - trajectory.metadata["context_time_cost"] = time.time() - start_time - return response["context_msg"]["content"] - - async def async_call_context_generator(self, executor=None, **kwargs): - loop = asyncio.get_event_loop() - - def func(): - return self.call_context_generator(**kwargs) - - return await loop.run_in_executor(executor=executor, func=func) - - def call_summarizer(self, trajectories: List[Trajectory], workspace_id: str = "default", **kwargs): - start_time = time.time() - - self.url = self.base_url + "/summarizer" - json_data = { - "trajectories": [x.model_dump() for x in trajectories], - "workspace_id": workspace_id, - "metadata": kwargs - } - response = self.request(json_data=json_data, headers={"Content-Type": "application/json"}) - if response is None: - logger.warning("error call_context_generator") - return "", time.time() - start_time - - return response["experiences"], time.time() - start_time - - async def async_call_summarizer(self, executor=None, **kwargs): - loop = asyncio.get_event_loop() - - def func(): - return self.call_summarizer(**kwargs) - - return await loop.run_in_executor(executor=executor, func=func) - - -def main(): - client = EMClient() - traj = Trajectory( - steps=[ - { - "role": "user", - "content": "What is the capital of France?" - }, - { - "role": "assistant", - "content": "Paris" - } - ], - query="What is the capital of France?", - reward=Reward(outcome=1.0) - ) - workspace_id = "w_agent_enhanced" - - print(client.call_summarizer(trajectories=[traj], workspace_id=workspace_id)) - print(client.call_context_generator(traj, retrieve_top_k=3, workspace_id=workspace_id)) - -if __name__ == "__main__": - main() diff --git a/astune/env_service_client/env_client.py b/astune/env_service_client/env_client.py deleted file mode 100644 index 656acc62..00000000 --- a/astune/env_service_client/env_client.py +++ /dev/null @@ -1,131 +0,0 @@ -# env_client.py -from typing import Dict, List, Any - -import requests - - -class EnvClient: - def __init__(self, base_url: str = "http://localhost:8000"): - self.base_url = base_url.rstrip("/") - self.timeout = 300.0 - - def _make_request( - self, - endpoint: str, - env_type: str = "default", - task_id: str = None, - instance_id: str = None, - messages: Dict[str, Any] = None, - params: Dict[str, Any] = None, - ) -> Dict: - """统一的请求处理方法""" - url = f"{self.base_url}/{endpoint}" - data = { - "env_type": env_type, - "task_id": task_id, - "instance_id": instance_id, - "messages": messages or {}, - "params": params or {}, - } - try: - response = requests.post(url, json=data, timeout=self.timeout) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - raise Exception(f"Request failed: {str(e)}, data: {data}") - - def get_task_ids( - self, env_type: str, split: str = "train", params: dict | None = None - ) -> List[str]: - """获取任务ID列表""" - payload: dict = {"env_type": env_type} - if params: - payload["params"] = params - response = self._make_request( - endpoint="get_task_ids", env_type=env_type, params={"split": split} - ) - return response["data"] - - def get_tools_info( - self, instance_id: str, messages: Dict = {}, params: Dict = {} - ) -> float: - """获取环境信息""" - response = self._make_request( - endpoint="get_info", - instance_id=instance_id, - messages=messages, - params=params, - ) - return response["data"] - - def create_instance( - self, env_type: str, task_id: str, instance_id: str = None, params: Dict = None - ) -> dict: - """创建环境实例""" - response = self._make_request( - endpoint="create", - env_type=env_type, - task_id=task_id, - instance_id=instance_id, - params=params, - ) - return response["data"] - - def step(self, instance_id: str, action: Dict = {}, params: Dict = {}) -> dict: - """执行环境步骤""" - response = self._make_request( - endpoint="step", instance_id=instance_id, messages=action, params=params - ) - return response["data"] - - def evaluate( - self, instance_id: str, messages: Dict = {}, params: Dict = {} - ) -> float: - """评估环境实例""" - response = self._make_request( - endpoint="evaluate", - instance_id=instance_id, - messages=messages, - params=params, - ) - return response["data"] - - def release_instance(self, instance_id: str) -> bool: - """释放环境实例""" - response = self._make_request(endpoint="release", instance_id=instance_id) - return response["success"] - - -# 使用示例 -def main(): - client = EnvClient() - - env_type = "appworld" - # 获取任务列表 - task_ids = client.get_task_ids(env_type) - print(f"Available tasks: {task_ids}") - - # 创建实例 - task_id = task_ids[0] - init_response = client.create_instance(env_type, task_id) - print("init state", init_response) - instance_id = init_response["info"]["instance_id"] - query = init_response["state"] - print(f"Created instance {instance_id} with query: {query}") - - # 执行动作 - action = {"role": "assistant", "content": "print('hello appworld!!')"} - result = client.step(instance_id, action) - print(f"Step result: {result}") - - # 评估 - score = client.evaluate(instance_id) - print(f"Evaluation score: {score}") - - # 释放实例 - success = client.release_instance(instance_id) - print(f"Instance released: {success}") - - -if __name__ == "__main__": - main() diff --git a/astune/main_trinity.py b/astune/main_trinity.py deleted file mode 100644 index e3c63d93..00000000 --- a/astune/main_trinity.py +++ /dev/null @@ -1,338 +0,0 @@ -""" -Modified from trinity.cli.launcher -""" - -import argparse -import asyncio -import os -import sys -import traceback -from pathlib import Path -from pprint import pprint - -import ray - -from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline -from trinity.common.config import Config, load_config -from trinity.common.constants import DEBUG_NAMESPACE, PLUGIN_DIRS_ENV_VAR -from trinity.explorer.explorer import Explorer -from trinity.manager.state_manager import StateManager -from trinity.trainer.trainer import Trainer -from trinity.utils.dlc_utils import is_running, setup_ray_cluster, stop_ray_cluster -from trinity.utils.log import get_logger -from trinity.utils.plugin_loader import load_plugins - -# register trinity backbone modules -import astune.backbone_trinity.register_flow # noqa: F401 - -logger = get_logger(__name__) - - -def bench(config: Config) -> None: - """Evaluate model.""" - config.explorer.name = "benchmark" - try: - explorer = Explorer.get_actor(config) - ray.get(explorer.prepare.remote()) - ray.get(explorer.benchmark.remote()) - logger.info("Benchmark finished.") - ray.get(explorer.shutdown.remote()) - except Exception: - logger.error(f"Benchmark failed:\n{traceback.format_exc()}") - - -def explore(config: Config) -> None: - """Run explorer.""" - try: - explorer = Explorer.get_actor(config) - ray.get(explorer.prepare.remote()) - ray.get(explorer.sync_weight.remote()) - ray.get(explorer.explore.remote()) - ray.get(explorer.shutdown.remote()) - except Exception: - logger.error(f"Explorer failed:\n{traceback.format_exc()}") - - -def train(config: Config) -> None: - """Run trainer.""" - try: - trainer = Trainer.get_actor(config) - ray.get(trainer.prepare.remote()) - ray.get(trainer.sync_weight.remote()) - ray.get(trainer.train.remote()) - ray.get(trainer.shutdown.remote()) - except Exception: - logger.error(f"Trainer failed:\n{traceback.format_exc()}") - - -def serve(config: Config) -> None: - """Run explorer in server mode.""" - try: - explorer = Explorer.get_actor(config) - ray.get(explorer.prepare.remote()) - ray.get(explorer.sync_weight.remote()) - ray.get(explorer.serve.remote()) - ray.get(explorer.shutdown.remote()) - except Exception: - logger.error(f"Explorer failed:\n{traceback.format_exc()}") - - -def both(config: Config) -> None: - """Setup both explorer and trainer. - - For the explorer, a step contains `batch_size * sync_interval` number - of rollout tasks. - - For the trainer, it has to consume all experiences generated by the explorer in - the latest step. The specific number of experiences may vary for different - algorithms and tasks. - """ - try: - explorer = Explorer.get_actor(config) - trainer = Trainer.get_actor(config) - ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()]) - ray.get( - [ - explorer.prepare.remote(), - trainer.prepare.remote(), - ] - ) - ray.get( - [ - explorer.sync_weight.remote(), - trainer.sync_weight.remote(), - ] - ) - ready_ref, wait_ref = ray.wait( - [ - explorer.explore.remote(), - trainer.train.remote(), - ], - num_returns=1, - ) - - ready = ray.get(ready_ref[0]) - if ready == config.trainer.name: - logger.info( - "===========================================================\n" - "> Launcher detected that the `Trainer` process has finished.\n" - "> Stopping the explorer process immediately.\n" - "===========================================================" - ) - ray.wait(wait_ref, timeout=5) - elif ready == config.explorer.name: - logger.info( - "===============================================================\n" - "> Launcher detected that the `Explorer` process has finished.\n" - "> `Trainer` process may need to save the model checkpoint.\n" - f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n" - "> You can force stop the `Trainer` process by pressing Ctrl+C.\n" - "===============================================================" - ) - ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout) - ray.wait( - [explorer.shutdown.remote(), trainer.shutdown.remote()], - timeout=config.synchronizer.sync_timeout, - num_returns=2, - ) - except Exception: - logger.error(f"Explorer or Trainer failed:\n{traceback.format_exc()}") - - -MODE_MAP = { - "explore": explore, - "train": train, - "both": both, - "bench": bench, - "serve": serve, -} - - -def run_stage(config: Config) -> None: - ray.init( - address=config.cluster.ray_address, - ignore_reinit_error=True, - namespace=config.ray_namespace, - runtime_env={"env_vars": config.get_envs()}, - ) - pprint(config) - try: - check_and_run_task_pipeline(config) - MODE_MAP[config.mode](config) - finally: - if config.monitor.enable_ray_timeline: - timeline_file = os.path.join(config.monitor.cache_dir, "timeline.json") - logger.info(f"Exporting Ray timeline to {timeline_file}...") - ray.timeline(filename=timeline_file) - logger.info("Done. You can open the timeline file in `chrome://tracing`") - ray.shutdown() - - -def run(config_path: str, dlc: bool = False, plugin_dir: str = None): - if os.path.exists(".env"): - from dotenv import load_dotenv - load_dotenv(".env") - if plugin_dir: - os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir - load_plugins() - config = load_config(config_path) - - if dlc: - cluster_namespace = f"{config.project}-{config.name}" - config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace) - - if not is_running(): - raise RuntimeError("Ray is not running, please start it by `ray start --head`.") - - try: - from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path - - if config.stages: - state_manager = StateManager( - path=os.path.join(config.checkpoint_root_dir, config.project, config.name) - ) - latest_stage = state_manager.load_stage().get("latest_stage", 0) - prev_stage_checkpoint = None - for i, stage_config in enumerate(config): - if i < latest_stage: - logger.info( - "===========================================================\n" - f"> Skipping completed stage {i + 1}/{len(config.stages)}...\n" - "===========================================================" - ) - else: - logger.info( - "===========================================================\n" - f"> Starting stage {i + 1}/{len(config.stages)}...\n" - "===========================================================" - ) - state_manager.save_stage(i) - if prev_stage_checkpoint is not None: - stage_config.model.model_path = prev_stage_checkpoint - stage_config.check_and_update() - run_stage(stage_config) - logger.info( - "===========================================================\n" - f"> Stage {i + 1}/{len(config.stages)} finished.\n" - "===========================================================" - ) - prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config) - else: - config.check_and_update() - run_stage(config) - - finally: - if dlc: - stop_ray_cluster(namespace=cluster_namespace) - - -def studio(port: int = 8501): - from streamlit.web import cli as stcli - - current_dir = Path(__file__).resolve().parent.parent - config_manager_path = os.path.join(current_dir, "manager", "config_manager.py") - - sys.argv = [ - "streamlit", - "run", - config_manager_path, - "--server.port", - str(port), - "--server.fileWatcherType", - "none", - ] - sys.exit(stcli.main()) - - -def debug( - config_path: str, - module: str, - output_file: str = "debug_workflow_runner.html", - plugin_dir: str = None, -): - """Debug a module.""" - if plugin_dir: - os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir - load_plugins() - config = load_config(config_path) - config.check_and_update() - config.ray_namespace = DEBUG_NAMESPACE - ray.init( - namespace=config.ray_namespace, - runtime_env={"env_vars": config.get_envs()}, - ignore_reinit_error=True, - ) - from trinity.common.models import create_debug_inference_model - - if module == "inference_model": - create_debug_inference_model(config) - - elif module == "workflow": - from trinity.explorer.workflow_runner import DebugWorkflowRunner - - runner = DebugWorkflowRunner(config, output_file) - asyncio.run(runner.debug()) - else: - raise ValueError( - f"Only support 'inference_model' and 'workflow' for debugging, got {module}" - ) - - -def main() -> None: - """The main entrypoint.""" - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(dest="command", required=True) - - # run command - run_parser = subparsers.add_parser("run", help="Run RFT process.") - run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.") - run_parser.add_argument( - "--plugin-dir", - type=str, - default=None, - help="Path to the directory containing plugin modules.", - ) - run_parser.add_argument( - "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC." - ) - - # studio command - studio_parser = subparsers.add_parser("studio", help="Run studio.") - studio_parser.add_argument( - "--port", type=int, default=8501, help="The port for Trinity-Studio." - ) - - # debug command - debug_parser = subparsers.add_parser("debug", help="Debug the code.") - debug_parser.add_argument("--config", type=str, help="Path to the config file.") - debug_parser.add_argument( - "--module", - type=str, - choices=["inference_model", "workflow"], - help="The module to start debugging, only support 'inference_model' and 'workflow' for now.", - ) - debug_parser.add_argument( - "--plugin-dir", - type=str, - default=None, - help="Path to the directory containing plugin modules.", - ) - debug_parser.add_argument( - "--output-file", - type=str, - default="debug_workflow_runner.html", - help="The output file for viztracer.", - ) - - args = parser.parse_args() - if args.command == "run": - # TODO: support parse all args from command line - run(args.config, args.dlc, args.plugin_dir) - elif args.command == "studio": - studio(args.port) - elif args.command == "debug": - debug(args.config, args.module, args.output_file, args.plugin_dir) - - -if __name__ == "__main__": - main() diff --git a/astune/main_vllm.py b/astune/main_vllm.py deleted file mode 100644 index 085ef600..00000000 --- a/astune/main_vllm.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -import sys -import hydra - -from openai import OpenAI -from types import SimpleNamespace -from astune.schema.task import Task -from beast_logger import register_logger -from astune.parallel_env import ParallelEnvManager - -class TokenAndProb: - def __init__(self, t): - # ChatCompletionTokenLogprob(token='token_id:73594', bytes=[96, 96, 96], logprob=-1.9073468138230965e-06, top_logprobs=[]) - self.token_id = int(t.token.split('token_id:')[-1]) - self.logprob = t.logprob - try: - self.decoded_string = bytes(t.bytes).decode('utf-8') - except: - self.decoded_string = '' + str(t.bytes) - -class ChatCompletionScheduler(): - - def __init__(self, url, config): - from transformers import AutoTokenizer - self.url = url - self.config = config - self.tokenizer = AutoTokenizer.from_pretrained(self.config.astune.model.path) - self.chat_scheduler = SimpleNamespace( - model_name="dummy-model-name", - weighted_addresses="dummy-weighted-addresses", - completion_callback=SimpleNamespace(tokenizer=self.tokenizer), - ) - - def submit_chat_completions(self, messages, sampling_params, request_id): - client = OpenAI( - base_url=self.url, - api_key="token-abc123", - ) - sampling_params = dict( - n=1, - max_completion_tokens=self.config.astune.rollout.max_response_length_in_one_turn, - temperature=self.config.astune.rollout.temperature, - top_p=self.config.astune.rollout.top_p - ) - sampling_params["temperature"] = self.config.astune.rollout.val_kwargs.temperature - sampling_params["top_k"] = self.config.astune.rollout.val_kwargs.top_k - sampling_params["top_p"] = self.config.astune.rollout.val_kwargs.top_p - sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) - - completion = client.chat.completions.create( - model=self.config.astune.model.path, - messages=messages, - extra_body=sampling_params - ) - - message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) - if "content" not in message: message["content"] = "" - t = {"role": message["role"], "request_id":completion.id, "content": message['content'], "tokens": [TokenAndProb(t) for t in completion.choices[0].logprobs.content]} - messages.append(t) - - return messages - - -def run(config): - # --------- fast adjustment for debugging --------- - max_parallel = config.astune.debug.debug_max_parallel - n_task = config.astune.debug.debug_first_n_tasks - vllm_port = config.astune.debug.debug_vllm_port - - # --------- init --------- - async_rollout_manager = ChatCompletionScheduler(config=config, url=f"http://localhost:{vllm_port}/v1") - parallel_env = ParallelEnvManager( - config=config, - async_rollout_manager=async_rollout_manager, - max_parallel=max_parallel, - max_llm_retries=3, - llm_mode="remote", - tokenizer=async_rollout_manager.tokenizer - ) - - from astune.task_reader.task_reader_base import TaskReaderRouter - task_reader = TaskReaderRouter(config) - tasks = task_reader.get_validation_tasks() - print(tasks[:2]) - cmt = parallel_env.rollout(tasks=tasks[:n_task], mode="sample", epoch='1') # "sample" or "validate" - gen_batch_output = parallel_env.to_dataproto(cmt) - print("Generated batch output") - - -@hydra.main(config_path="astune/default_config", config_name="astune_default", version_base=None) -def main(config): - from omegaconf import OmegaConf - OmegaConf.resolve(config) - print('*' * 20) - - def companion_launch(): - from astune.utils.smart_daemon import LaunchCommandWhenAbsent - import torch - print("Launching companion process for async LLM server...") - model_path = config.astune.model.path - tensor_parallel_size = config.astune.debug.debug_tensor_parallel_size - n_avail_gpus = torch.cuda.device_count() - if tensor_parallel_size > n_avail_gpus: - print(f"Warning: tensor_parallel_size {tensor_parallel_size} is greater than available GPUs {n_avail_gpus}. Setting tensor_parallel_size to {n_avail_gpus}.") - tensor_parallel_size = n_avail_gpus - gpu_memory_utilization = config.actor_rollout_ref.rollout.gpu_memory_utilization - max_num_seqs = config.actor_rollout_ref.rollout.max_num_seqs - max_model_len = config.astune.rollout.max_model_len - seed = config.astune.debug.debug_vllm_seed - vllm_port = config.astune.debug.debug_vllm_port - companion = LaunchCommandWhenAbsent( - full_argument_list=[ - sys.executable, "-m", - f"vllm.entrypoints.cli.main", - f"serve", f"{model_path}", - f"--tensor-parallel-size", f"{tensor_parallel_size}", - f"--dtype", f"auto", - f"--enforce-eager", - f"--gpu-memory-utilization", f"{gpu_memory_utilization}", - f"--disable-custom-all-reduce", - f"--max-num-seqs", f"{max_num_seqs}", - f"--max-model-len", f"{max_model_len}", - f"--load-format", "auto", - f"--enable-chunked-prefill", - f"--enable-prefix-caching", - f"--seed", f"{seed}", - f"--port", f"{vllm_port}", - ], - dir='./', - tag="external_vllm_server" - ) - companion.launch(launch_wait_time=1800, success_std_string="Application startup complete", env_dict={**os.environ}) - companion_launch() - - run(config) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/astune/parallel_env.py b/astune/parallel_env.py deleted file mode 100644 index f3648aee..00000000 --- a/astune/parallel_env.py +++ /dev/null @@ -1,723 +0,0 @@ -import os -import copy -import time -import numpy as np -import torch -import uuid -from concurrent.futures import ThreadPoolExecutor -from typing import Dict, List, Literal, Callable, Any -from loguru import logger -from omegaconf import DictConfig -from tensordict import TensorDict -from torch.nn.utils.rnn import pad_sequence -from tqdm import tqdm -from verl import DataProto -from verl.utils.torch_functional import pad_sequence_to_length -from astune.agent_flow import AgentFlow -from astune.agent_flow import BaseAgentFlow -from astune.env import EnvWorker -from astune.schema.task import Task, TaskLaunchCoreArgument -from astune.schema.trajectory import Sample -from astune.context_manager.cmt_linear import CMTLinear, CMTBaseAttr -from beast_logger import register_logger, print_dict, print_listofdict -from astune.agentscope_flow import AgentScopeWorkflow -from astune.utils.utils import run_async_coro__no_matter_what_the_fuck -from pydantic import BaseModel, Field - - -def init_logger(experiment_name): - """Initialize the logger with the given configuration.""" - if 'BEST_LOGGER_INIT' in os.environ: return # prevent re-initialization in ray environment - os.environ['BEST_LOGGER_INIT'] = '1' - from datetime import datetime - final_log_path = os.path.join( "launcher_record", experiment_name, datetime.now().strftime("%Y_%m_%d_%H_%M") ) - os.environ['BEST_LOGGER_PATH'] = final_log_path - non_console_mods = ["rollout", "token_clip", "bad_case", "env_clip"] - register_logger(mods=["evaluation", "exception"], non_console_mods=non_console_mods, auto_clean_mods=[], base_log_path=final_log_path, debug=False) - -class AsyncLlmBridge(object): - - def __init__(self, config: DictConfig, async_rollout_manager, max_parallel: int, - max_llm_retries: int = 3, tokenizer: "AutoTokenizer"=None, llm_mode= "local", **kwargs): - - init_logger(experiment_name=config.astune.experiment_name) - self.llm_mode = llm_mode - self.config: DictConfig = config - self.async_rollout_manager = async_rollout_manager - self.max_parallel: int = max_parallel - self.max_llm_retries: int = max_llm_retries - - self.rollout_n = config.astune.rollout.num_repeat - # self.model_name = self.async_rollout_manager.chat_scheduler.model_name - self.tokenizer = tokenizer - self.pad_token_id = self.tokenizer.pad_token_id - self.current_token = 0 - self.current_global_steps = "NA" - - - def get_llm_chat_fn(self, sampling_params: dict = None) -> Callable: - import asyncio, uuid - from astune.schema.logprob import TokenAndProb - def llm_chat(messages: List[Dict[str, str]], - custom_sampling_params: dict = None, - request_id: str = None) -> dict: - """ - input messages: [{"role": "system", "value": "..."}, {"role": "user", "value": "..."}] - output messages: [{"role": "assistant", "value": "..."}] - """ - # TODO: sending sampling_params to rollout server - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - - # updated_sampling_params.update({"logprobs": 1, "prompt_logprobs": 1}) - input_messages = copy.deepcopy(messages) - request_id = uuid.uuid4().hex - prompt_ids = self.tokenizer.apply_chat_template(input_messages, add_generation_prompt=True, tokenize=True) - - final_res = run_async_coro__no_matter_what_the_fuck(self.async_rollout_manager.generate( - request_id=request_id, - prompt_ids=prompt_ids, - sampling_params=updated_sampling_params, - ) - ) - - if self.config.astune.rollout.name == 'vllm': - token_array = final_res.outputs[0].token_ids - elif self.config.astune.rollout.name == 'sglang': - token_array = final_res - - decoded_text = self.tokenizer.decode(token_array) - # decoded_text = "Let's start by finding which API we need to use to interact with Simple Note.\n\nCode: - # ```python\nprint(apis.api_docs.show_api_descriptions(app_name='simple_note'))\n```<|im_end|>" - if decoded_text.endswith('<|im_end|>'): - decoded_text = decoded_text[:-len('<|im_end|>')] - # assert prompt_ids == final_res.prompt_token_ids - # assert final_res.outputs[0].text == decoded_text - # a = self.tokenizer.apply_chat_template( - # input_messages + [{"role": "assistant", "content": decoded_text}], - # add_generation_prompt=False, tokenize=True) - # b = prompt_ids + token_array - # assert all([aa==bb for aa,bb in zip(a,b)]) - return { - "role": "assistant", - "request_id": request_id, - "content": decoded_text, - "tokens": [ - TokenAndProb( - token_id=token, - logprob=-1, - decoded_string=self.tokenizer.decode(token) - ) - for token in token_array - ] - } - - def llm_chat_remote(messages: List[Dict[str, str]], - custom_sampling_params: dict = None, - request_id: str = None) -> dict: - """ - input messages: [{"role": "system", "value": "..."}, {"role": "user", "value": "..."}] - output messages: [{"role": "assistant", "value": "..."}] - """ - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) - input_messages = copy.deepcopy(messages) - for i in range(self.max_llm_retries): - try: - output_message = self.async_rollout_manager.submit_chat_completions(messages=input_messages, - sampling_params=updated_sampling_params, - request_id=request_id) - break - except Exception as e: - logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}") - time.sleep(i + 1) - - return output_message[-1] - - def llm_chat_trinity(messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - request_id: str = "") -> dict: - """ - input messages: [{"role": "system", "value": "..."}, {"role": "user", "value": "..."}] - output messages: [{"role": "assistant", "value": "..."}] - """ - async def main(model_client): - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - updated_sampling_params.pop('min_tokens') - response = await model_client.chat.completions.create( - model=model_client.model_path, - messages=messages, - logprobs=True, - top_logprobs=0, - **updated_sampling_params - ) - return response - - assert hasattr(self, 'trinity_llm_model_client'), "trinity_llm_model_client is not set in AsyncLlmBridge" - response = run_async_coro__no_matter_what_the_fuck(main(self.trinity_llm_model_client)) # type: ignore - from vsdb import bp - bp('INFER') - return { - "role": "assistant", - "request_id": response.id, - "content": response.choices[0].message.content, - "tokens": [ - TokenAndProb( - token_id=token, - logprob=tokenlogprob.logprob, - decoded_string=tokenlogprob.token - ) - for tokenlogprob, token in zip(response.choices[0].logprobs.content, response.choices[0].token_ids) - ] - } - - if self.llm_mode == "remote": - return llm_chat_remote - if self.llm_mode == "trinity": - return llm_chat_trinity - else: - return llm_chat - - -class StepPrinter(AsyncLlmBridge): - - def step_status_printer(self, obs_window): - # 直方数据,tmux 0~10 数量 10~20 数量 20~30 数量 30~40 数量 …… - step_counter = {} - - current_token = sum(obs_window['token']) - current_time = time.time() - delta_token = current_token - self.current_token - if delta_token < 0: delta_token = current_token # 下一次rollout开始了,tmux['token']会清零,简单处理一下就好 - delta_time = current_time - self.current_token_count_time - self.current_token = current_token - self.current_token_count_time = current_time - token_gen_per_sec_str = f"{delta_token/delta_time:.2f} tokens/s" if delta_time > 0 else "N/A" - - - for step in obs_window['step']: - if step == -1: - step_counter[(-1, 'terminated')] = step_counter.get((-1, 'terminated'), 0) + 1 - continue - else: - start = (step // 5) * 5 - end = start + 5 - step_counter[(start, end)] = step_counter.get((start, end), 0) + 1 - - # sort by start value (small to large) - step_counter = dict(sorted(step_counter.items(), key=lambda x: x[0][0])) - - print_buf = [] - for (start, end), count in step_counter.items(): - if start != -1: - print_buf += [f"[{start}-{end}]:{count} threads"] - for (start, end), count in step_counter.items(): - if start == -1: - print_buf += [f"[finished]:{count} threads"] - print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf)) - -class StaticRollout(StepPrinter, AsyncLlmBridge): - - def rollout_env_worker(self, task: Task, task_batch_index: int, task_tag: str, mode: Literal["sample", "validate"], - task_thread_index: int, obs_window: dict, **kwargs) -> CMTLinear: - """ - Process a single prompt in a thread-safe way. - """ - def get_sample_params(): - response_length_eps = 6 # 减少几个token给lm_start等special token的后续处理留余地 - if self.config.astune.rollout.name == 'vllm': - sampling_params = dict( - n=1, - max_tokens=self.config.astune.rollout.max_response_length_in_one_turn - response_length_eps, - min_tokens=1, # 必须至少输出1个token - temperature=self.config.astune.rollout.temperature, - top_p=self.config.astune.rollout.top_p - ) - else: - sampling_params = dict( - n=1, - max_new_tokens=self.config.astune.rollout.max_response_length_in_one_turn, - temperature=self.config.astune.rollout.temperature, - top_p=self.config.astune.rollout.top_p - ) - - if mode == "validate": - sampling_params["temperature"] = self.config.astune.rollout.val_kwargs.temperature - sampling_params["top_k"] = self.config.astune.rollout.val_kwargs.top_k - sampling_params["top_p"] = self.config.astune.rollout.val_kwargs.top_p - return sampling_params - - - max_retry = 3 - for retry in range(max_retry): - try: - llm_chat_fn = self.get_llm_chat_fn(get_sample_params()) - cmt: CMTBaseAttr = EnvWorker( - task_core_arg=TaskLaunchCoreArgument( - env_type=task.env_type, - task_id=task.task_id, - task_thread_index=task_thread_index, - task_batch_index=task_batch_index, - task_env_uuid=uuid.uuid4().hex, - task_tag=task_tag, - obs_window=obs_window, - llm_chat_fn=llm_chat_fn, - tokenizer=self.tokenizer, - task=task - ), - config=self.config - ).execute() - break - except Exception as e: - if retry < max_retry - 1: - logger.bind(exception=True).exception(f"rollout_env_worker error: {e.args}, retrying {retry + 1}/{max_retry}") - time.sleep(2 ** retry) - else: - logger.bind(exception=True).exception(f"rollout_env_worker failed after {max_retry} retries: {e.args}") - raise e - - return cmt # type: ignore - - - def rollout(self, tasks: List[Task], mode: Literal["sample", "validate"], epoch: str) -> List[CMTLinear]: - # 1. if enable_oversample - self.current_token_count_time = time.time() - # 2. otherwise, use legacy rollout method - cmt_array: List[CMTLinear] = [] - rollout_n = 1 if mode=="validate" else self.rollout_n - obs_window = { - 'step': [0 for _ in range(len(tasks) * rollout_n)], - 'token': [0 for _ in range(len(tasks) * rollout_n)], - 'stop': [False for _ in range(len(tasks) * rollout_n)], - } - with ThreadPoolExecutor(max_workers=self.max_parallel) as executor: - futures = [] - for task_batch_index, task in enumerate(tasks): - for task_rollout_index in range(rollout_n): - task_thread_index = task_batch_index * rollout_n + task_rollout_index - future = executor.submit(self.rollout_env_worker, - task=task, task_batch_index=task_batch_index, - task_tag=f"T{task.task_id}#R{task_rollout_index}", - mode=mode, - task_thread_index=task_thread_index, - obs_window=obs_window) - futures.append(future) - - while any(future.running() for future in futures): - self.step_status_printer(obs_window) - time.sleep(10) - - for future in tqdm(futures, desc=f"epoch{epoch}.collect_rollout"): - # do not fail silently - result = future.result() - cmt_array.append(result) - - task_success_rate = np.mean([cmt.reward_structure.success_rate for cmt in cmt_array]) - for cmt in cmt_array: - cmt.current_batch_success_rate = float(task_success_rate) - - return cmt_array - - -class DynamicRollout(StaticRollout): - - def rollout(self, tasks: List[Task], mode: Literal["sample", "validate"], epoch: str) -> List[CMTLinear]: - if mode=="sample" and (self.rollout_n!=1) and self.config.astune.rollout.enable_oversample: - return self.rollout_dynamic(tasks, mode, epoch) - else: - return super().rollout(tasks, mode, epoch) - - def greedy_max_std_selection(self, samples: List[CMTLinear], n): - if len(samples) < n: - additional_n = n - len(samples) - n = len(samples) - else: - additional_n = 0 - - sorted_samples = sorted(samples, key=lambda cmt: abs(cmt.reward_structure.performance_reward)) - value_array = [cmt.reward_structure.performance_reward for cmt in sorted_samples] - macro_selected_value = [] - macro_selected_index = [] - while len(macro_selected_index) != n: - selected_value = [] - selected_index = [] - for index, value in enumerate(value_array): - if (value not in selected_value) and (index not in macro_selected_index): - selected_value.append(value) - selected_index.append(index) - - if len(selected_value) + len(macro_selected_value) <= n: - macro_selected_value += selected_value - macro_selected_index += selected_index - - elif len(selected_value) + len(macro_selected_value) > n: - preserve_n = n - len(macro_selected_value) - # 从 selected_value 和 selected_index 两端选择 preserve_n 个样本 - pick_left = preserve_n // 2 - pick_right = preserve_n - pick_left - macro_selected_value += selected_value[:pick_left] + selected_value[-pick_right:] - macro_selected_index += selected_index[:pick_left] + selected_index[-pick_right:] - - if additional_n > 0: - # randomly select `additional_n` samples from `macro_selected_index`, then concat to `macro_selected_index` - additional_indices = np.random.choice(macro_selected_index, additional_n, replace=True) - macro_selected_index += additional_indices.tolist() - - selected_samples = [sorted_samples[i] for i in macro_selected_index] - sorted_selected_samples = sorted(selected_samples, key=lambda cmt: abs(cmt.reward_structure.performance_reward)) - return sorted_selected_samples - - - def rollout_dynamic(self, tasks: List[Task], mode: Literal["sample", "validate"], epoch: str, allow_sample_num_change=True, allow_force_stop=True) -> List[CMTLinear]: - """ - Rollout more - """ - cmt_array: List[CMTLinear] = [] - assert mode != "validate" - rollout_n = self.rollout_n - self.current_token_count_time = time.time() - submit_oversample_multiplier = self.config.astune.rollout.submit_oversample_multiplier - rollout_n_oversample = int(rollout_n * submit_oversample_multiplier) - rollout_n_confirm = int(rollout_n * (1 + submit_oversample_multiplier) / 2) - assert rollout_n < rollout_n_confirm < rollout_n_oversample, \ - f"submit_oversample_multiplier is too small, rollout_n={rollout_n}, rollout_n_confirm={rollout_n_confirm}, rollout_n_oversample={rollout_n_oversample}" - - obs_window = { - 'step': [0 for _ in range(len(tasks) * rollout_n_oversample)], - 'stop': [False for _ in range(len(tasks) * rollout_n_oversample)], - 'token': [0 for _ in range(len(tasks) * rollout_n_oversample)], - } - - with ThreadPoolExecutor(max_workers=self.max_parallel) as executor: - # 提交线程 - futures = [] - for task_batch_index, task in enumerate(tasks): - task_future_array = [] - for task_rollout_index in range(rollout_n_oversample): - task_thread_index = task_batch_index * rollout_n_oversample + task_rollout_index - future = executor.submit(self.rollout_env_worker, - task=task, - task_batch_index=task_batch_index, - task_tag=f"T{task.task_id}#R{task_rollout_index}", # task_rollout_index=str(task_rollout_index), - mode=mode, - task_thread_index=task_thread_index, - obs_window=obs_window) - task_future_array.append(future) - futures += [task_future_array] - - tic = -1 - # 记录已完成线程的结果 - while True: - tic += 1 - can_terminate = [False for _ in futures] - terminate_status = ['running' for _ in futures] - for j, task_future_array in enumerate(futures): - completed_task_futures = [f for f in task_future_array if f.done()] - completed_results = [f.result() for f in completed_task_futures] - completed_results = [cmt for cmt in completed_results if not cmt.discarded] - reward = [cmt.reward_structure.performance_reward for cmt in completed_results] - reward_std = np.std(reward) if reward else 0.0 - all_finished = (len(completed_task_futures) == len(task_future_array)) - # finish condition 1: all oversample tasks are finished - if all_finished: - can_terminate[j] = True - terminate_status[j] = f'all_fin({len(completed_results)}/{reward_std:.2f})' - num_finished = len(completed_task_futures) - task_cmd_reward_array = [cmt.reward_structure.performance_reward for cmt in completed_results] - all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) - # all_reward_greater_than_one = all(x >= 1 for x in task_cmd_reward_array) - if not all_equal: - if (num_finished >= rollout_n): - # finish condition 2: more than rollout_n tasks are finished, and, reward are not all equal - can_terminate[j] = True - terminate_status[j] = f'early_end({len(completed_results)}/{reward_std:.2f})' - else: - pass # keep waiting - else: - if num_finished >= rollout_n_confirm: - # finish condition 3: if more than rollout_n_confirm tasks are finished, we can confirm this task is hopeless (or successful for certainty) - can_terminate[j] = True - terminate_status[j] = f'confirm_dummy({len(completed_results)}/{reward_std:.2f})' - # take actions to stop future rollout - if allow_force_stop: - for k in range(j*rollout_n_oversample, j*rollout_n_oversample + rollout_n_oversample): - obs_window['stop'][k] = True - else: - pass # keep waiting - # check global status - terminate_status = '/'.join(terminate_status) - if all(can_terminate): - logger.info(f"epoch{epoch}.collect_rollout: all tasks finished, exiting loop") - for i, stop_flag in enumerate(obs_window['stop']): obs_window['stop'][i] = True # all must stop now - break - else: - if tic % 10 == 0: - self.step_status_printer(obs_window) # print status every 10*5=50 seconds - logger.info(f"task complete {sum(can_terminate)}/{len(can_terminate)} tasks: {terminate_status}") - time.sleep(5) - # 等待所有线程完成或者被迫中止 - tic = -1 - while any(f.running() for task_future_array in futures for f in task_future_array): - tic += 1 - if tic % 10 == 0: logger.info('waiting final sync, this will not take long') - time.sleep(5) - - # 检查到底有多少thread完成了预定任务 - task_ineffective_thread_cnt = [] - task_completed_thread_cnt = [] - task_extra_thread_cnt = [] - task_need_amend = 0 - for j, task_future_array in enumerate(futures): - # get number of completed tasks - completed_task_futures = [f for f in task_future_array if f.done()] - completed_results = [f.result() for f in completed_task_futures] - completed_results = [cmt for cmt in completed_results if not cmt.discarded] - task_cmd_reward_array = [cmt.reward_structure.performance_reward for cmt in completed_results] - all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) - # 计数 - completed_task_cnt = len(completed_results) - if all_equal: - task_need_amend += 1 - task_completed_thread_cnt += [0] - task_extra_thread_cnt += [0] - task_ineffective_thread_cnt += [completed_task_cnt] - else: - task_need_amend += 0 - task_completed_thread_cnt += [completed_task_cnt] - task_extra_thread_cnt += [completed_task_cnt - rollout_n] - task_ineffective_thread_cnt += [0] - - logger.info(f"task_completed_thread_cnt: {task_completed_thread_cnt}") - logger.info(f"task_extra_thread_cnt: {task_extra_thread_cnt}") - - world_size = self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes - total_sample = sum(task_completed_thread_cnt) - if allow_sample_num_change and (total_sample > world_size*2): - # 允许样本数量变化,我们只需要返回的样本能够被 显卡数 整除即可 - # add_count = (world_size - total_sample % world_size) # 如果采用添加策略,需要添加的样本数 - add_count = 0 # 如果采用添加策略,需要添加的样本数 - num_task_to_amend = len(futures) # num_task - logger.info(f"allow_sample_num_change policy: world_size: {world_size}, total_sample {total_sample}, add_count: {add_count}, ") - # 选择 extra 最少的task进行补偿 - while add_count != 0: - _task_completed_thread_cnt_find_nonzero_min = [float('inf') if x <=0 else x for x in task_completed_thread_cnt] - min_extra_index = _task_completed_thread_cnt_find_nonzero_min.index(min(_task_completed_thread_cnt_find_nonzero_min)) - task_extra_thread_cnt[min_extra_index] += 1 - task_completed_thread_cnt[min_extra_index] += 1 - add_count -= 1 - # logger.info(f"_task_completed_thread_cnt_find_nonzero_min: {_task_completed_thread_cnt_find_nonzero_min}") - logger.info(f"task_completed_thread_cnt (after remove): {task_completed_thread_cnt}") - logger.info(f"task_extra_thread_cnt (after remove): {task_extra_thread_cnt}") - else: - # 不允许样本数量变化,尝试补偿 - num_task_max_to_amend = sum(task_extra_thread_cnt) // rollout_n - num_task_to_amend = min(num_task_max_to_amend, task_need_amend) - extra_num_thread_required = num_task_to_amend * rollout_n - remove_count = sum(task_extra_thread_cnt) - extra_num_thread_required - logger.info(f"forbid_sample_num_change policy: num_task_max_to_amend: {num_task_max_to_amend}, num_task_to_amend: {num_task_to_amend}, remove_count: {remove_count}, ") - - # 选择 extra 最多的task进行约束 - while remove_count != 0: - max_extra_index = task_extra_thread_cnt.index(max(task_extra_thread_cnt)) - assert task_extra_thread_cnt[max_extra_index] > 0, "task_extra_thread_cnt should be greater than 0" - task_extra_thread_cnt[max_extra_index] -= 1 - task_completed_thread_cnt[max_extra_index] -= 1 - remove_count -= 1 - logger.info(f"task_completed_thread_cnt (after remove): {task_completed_thread_cnt}") - logger.info(f"task_extra_thread_cnt (after remove): {task_extra_thread_cnt}") - - # 筛选出方差最高的样本 - cmt_array = [] - print_buffer = "" - task_success_rate = [] - for j, task_future_array, avail_extra_cnt in zip(range(len(futures)), futures, task_extra_thread_cnt): - # get number of completed tasks - completed_task_futures = [f for f in task_future_array if f.done()] - completed_results = [f.result() for f in completed_task_futures] - completed_results: List[CMTLinear] = [cmt for cmt in completed_results if not cmt.discarded] - task_cmd_reward_array = [cmt.reward_structure.performance_reward for cmt in completed_results] - success_rate_array = [cmt.reward_structure.success_rate for cmt in completed_results] - task_success_rate += [np.mean(success_rate_array)] - need_amend = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) - if need_amend and (num_task_to_amend > 0): - num_task_to_amend -= 1 - print_buffer += f"/(amend)" - continue - else: - if need_amend: - num_completed = len(completed_results) - num_to_be_selected = rollout_n - else: - num_completed = len(completed_results) - num_to_be_selected = rollout_n + avail_extra_cnt - # assert num_completed >= num_to_be_selected, f"num_completed={num_completed}, num_to_be_selected={num_to_be_selected}" - selected_cmt_array = self.greedy_max_std_selection(completed_results, num_to_be_selected) - cmt_array += selected_cmt_array - print_buffer += f"/({len(selected_cmt_array)})" - if need_amend: print_buffer += "(no-amend)" - logger.info(print_buffer) - - for cmt in cmt_array: - cmt.current_batch_success_rate = np.mean(task_success_rate) - return cmt_array - - - - -class ParallelEnvManager(DynamicRollout): - - # TODO: define an extra class for trajectory-dataproto converting. - def to_dataproto(self, cmt_array) -> DataProto: - """Convert trajectories to DataProto""" - # Step 1: Convert trajectories to samples: tokenizing - samples = self.trajectories_to_samples(cmt_array) - - # Step 2: Convert samples to DataProto: padding - dataproto = self.samples_to_dataproto(samples) - - return dataproto - - def trajectories_to_samples(self, cmt_array: List[CMTLinear]) -> List[Sample]: - """Convert trajectories to samples""" - sample_arr_final = [] - CMTLinear.compute_reference_advantage(cmt_array) - for cmt in cmt_array: - try: - sample_arr = cmt.group_tokenize() - except Exception as e: - raise e - finally: - cmt.generate_log(global_step=self.current_global_steps) - if os.environ.get('BEST_LOGGER_PATH', None) and os.environ.get('ASTUNE_DEBUG', None): - logger.success(f"View rollout details at [https://localhost:8181/?path={os.path.abspath(os.environ['BEST_LOGGER_PATH'])}]") - sample_arr_final += sample_arr - - # Step 2: Calculate how many samples need to be removed - world_size = self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes - remainder = len(sample_arr_final) % world_size - if remainder != 0: - import random - remove_indices = random.sample(range(len(sample_arr_final)), remainder) - # Sort in reverse order to avoid index shifting during removal - remove_indices.sort(reverse=True) - for idx in remove_indices: - sample_arr_final.pop(idx) - - # random remove some samples, so that the number of samples is divisible by 8 - return sample_arr_final - - def samples_to_dataproto(self, samples: list[Sample]) -> DataProto: - # Initialize lists to store batched data - prompt_ids, response_ids = [], [] - prompt_attention_mask, response_attention_mask = [], [] - prompt_position_ids, response_position_ids = [], [] - prompt_loss_mask, response_loss_mask = [], [] - messages = [] - # reward_scores = [] # replace with step_reward_scores - step_reward_scores = [] - task_ids = [] - rollout_ids = [] - reference_advantage = [] - - for sample in samples: - # Validate that all fields have the same length - assert len(sample.input_ids) == len(sample.attention_mask) == len(sample.position_ids) == len( - sample.loss_mask), f"Sample {sample.request_id} has mismatched lengths: " \ - f"{len(sample.input_ids)=}, {len(sample.attention_mask)=}, " \ - f"{len(sample.position_ids)=}, {len(sample.loss_mask)=}" - - task_ids.append(sample.task_id) - rollout_ids.append(sample.task_tag) - # Discard samples with prompt length exceeding limit - if len(sample.prompt_ids) > self.config.astune.data.max_prompt_length: - raise RuntimeError(f"Sample has prompt_ids length {len(sample.prompt_ids)} ") - - # Warn if response is longer than expected (but still include it) - if len(sample.response_ids) > self.config.astune.data.max_response_length: - raise RuntimeError(f"Sample has prompt_ids length {len(sample.prompt_ids)} ") - - # Append tensors to respective lists - assert len(sample.prompt_ids) != 0 - assert len(sample.response_ids) != 0 - prompt_ids.append(torch.tensor(sample.prompt_ids, dtype=torch.int)) - response_ids.append(torch.tensor(sample.response_ids, dtype=torch.int)) - - prompt_attention_mask.append(torch.tensor(sample.prompt_attention_mask, dtype=torch.int)) - response_attention_mask.append(torch.tensor(sample.response_attention_mask, dtype=torch.int)) - - prompt_position_ids.append(torch.tensor(sample.prompt_position_ids, dtype=torch.int)) - response_position_ids.append(torch.tensor(sample.response_position_ids, dtype=torch.int)) - - prompt_loss_mask.append(torch.tensor(sample.prompt_loss_mask, dtype=torch.int)) - response_loss_mask.append(torch.tensor(sample.response_loss_mask, dtype=torch.int)) - - reference_advantage.append(sample.reference_advantage) - - messages.append({"messages": sample.messages}) - # reward_scores.append(sample.global_reward) - step_reward_scores.append(sample.step_reward) - - max_prompt_length_this_batch = max([p.shape[-1] for p in prompt_ids]) - assert max_prompt_length_this_batch <= self.config.astune.data.max_prompt_length - max_response_length_this_batch = max([p.shape[-1] for p in response_ids]) - assert max_response_length_this_batch <= self.config.astune.data.max_response_length - - # Batch and pad sequences - prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=self.pad_token_id, padding_side="left") - prompt_attention_mask = pad_sequence(prompt_attention_mask, batch_first=True, padding_value=0, padding_side="left") - prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left") - prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") - - prompt_ids = pad_sequence_to_length(prompt_ids, max_prompt_length_this_batch, self.pad_token_id, left_pad=True) - prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, max_prompt_length_this_batch, 0, left_pad=True) - prompt_position_ids = pad_sequence_to_length(prompt_position_ids, max_prompt_length_this_batch, 0, left_pad=True) - prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, max_prompt_length_this_batch, 0, left_pad=True) - - response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) - response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) - response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0) - - response_ids = pad_sequence_to_length(response_ids, max_response_length_this_batch, self.pad_token_id) - response_attention_mask = pad_sequence_to_length(response_attention_mask, max_response_length_this_batch, 0) - response_loss_mask = pad_sequence_to_length(response_loss_mask, max_response_length_this_batch, 0) - - delta_position_id = torch.arange(1, response_ids.size(1) + 1, device=response_ids.device).unsqueeze(0).repeat(len(samples), 1) - response_position_ids = prompt_position_ids[:, -1:] + delta_position_id - - # Concatenate prompt and response tensors - input_ids = torch.cat((prompt_ids, response_ids), dim=-1) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) - position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) - loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1) - - # Construct the batch using TensorDict - batch = TensorDict( - { - "prompts": prompt_ids, - "responses": response_ids, - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "loss_mask": loss_mask, - }, - batch_size=len(samples), - ) - - return DataProto(batch=batch, non_tensor_batch={ - "task_ids": np.array(task_ids), - "rollout_ids": np.array(rollout_ids), - "messages": np.array(messages), - "reward_scores": np.array(step_reward_scores), - "reference_advantage": np.array(reference_advantage), - }) \ No newline at end of file diff --git a/astune/protocol/agentscope_protocol.py b/astune/protocol/agentscope_protocol.py deleted file mode 100644 index dc38d0f8..00000000 --- a/astune/protocol/agentscope_protocol.py +++ /dev/null @@ -1,26 +0,0 @@ -from agentscope.message import Msg -from pydantic import BaseModel, Field -from typing import Callable, List -try: from astune.agentscope_flow import BeyondAgentProxy -except ImportError: pass - -class AgentScopeLearnProtocol(BaseModel): - model_config = {"extra": "allow"} - # Trainer to use; default "trinity". Optional: "agentscorpion-trinity". - trainer: str = Field(default="trinity") - # Experiment name - agentflow_name: str = Field(default="agent-flow") - # In multi-agent settings, specify the list of trainable agent target names - trainable_agent_targets: List[str] = Field(default=[]) - # Use dataset provided by the trainer (True: read each query from workflow input; False: AgentScope handles each query) - external_dataset: bool = Field(default=True) - # Use external environment provided by the trainer (True: read environment handle from input; False: AgentScope runs environment and tools) - external_environment: bool = Field(default=True) - # Use external reward provided by the trainer (True: compute reward outside AgentScope after workflow; False: AgentScope computes reward) - external_reward: bool = Field(default=True) - # Other settings - multiturn_token_consolidation: bool = Field(default=True) - - async def agentscope_execute(self, init_messages, beyondagent_proxy: "BeyondAgentProxy", config)->"BeyondAgentProxy": - raise NotImplementedError - diff --git a/astune/schema/logprob.py b/astune/schema/logprob.py deleted file mode 100644 index b425fc45..00000000 --- a/astune/schema/logprob.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any, Dict, List - -from loguru import logger -from omegaconf import DictConfig -from openai.types.chat.chat_completion import ChatCompletion -from verl import DataProto - -class TokenAndProb: - def __init__(self, token_id, logprob, decoded_string): - self.token_id = token_id - self.logprob = logprob - self.decoded_string = decoded_string - diff --git a/astune/schema/task.py b/astune/schema/task.py deleted file mode 100644 index a45abf6d..00000000 --- a/astune/schema/task.py +++ /dev/null @@ -1,24 +0,0 @@ -from pydantic import BaseModel, Field -from typing import List, Dict, Any - - -class Task(BaseModel): - main_query: str = Field(default="") - init_messages: List[dict] = Field(default=[]) - task_id: str = Field(default="") - env_type: str = Field(default="") - metadata: dict = Field(default_factory=dict) - - -class TaskLaunchCoreArgument(BaseModel): - env_type: str = Field(default="") - task_id: str = Field(default="") - task_thread_index: int = Field(default=0) - task_batch_index: int = Field(default=0) - task_tag: str = Field(default="") - task_env_uuid: str = Field(default="") - obs_window: dict = Field(default={}) - llm_chat_fn: Any = Field(default=None) - tokenizer: Any = Field(default=None) - task: Task = Field(default=None) # type: ignore - diff --git a/astune/task_judge/env_service_as_judge.py b/astune/task_judge/env_service_as_judge.py deleted file mode 100644 index 770dd2e2..00000000 --- a/astune/task_judge/env_service_as_judge.py +++ /dev/null @@ -1,30 +0,0 @@ -from astune.task_judge.judge_base import JudgeBase - - -class EnvServiceJudge(JudgeBase): - - def __init__(self, config): - self.config = config - - def compute_reward(self, judge_input_dictionary) -> tuple: - raw_reward = 0 - - env = judge_input_dictionary['env'] - task_core_arg = judge_input_dictionary['task_core_arg'] - - raw_reward = env.evaluate(task_core_arg.task_env_uuid, params={"sparse": False}) - if raw_reward >= 1: - is_success = True - else: - is_success = False - - if self.config.astune.rollout.add_special_success_reward: - if is_success: - raw_reward = 1.0 + raw_reward * 0.5 - else: - raw_reward = 0.0 + raw_reward * 0.5 - - if self.config.astune.rollout.binary_reward: - raw_reward = 1.0 if is_success else 0.0 - - return raw_reward, is_success \ No newline at end of file diff --git a/astune/task_judge/judge_base.py b/astune/task_judge/judge_base.py deleted file mode 100644 index 5d18d9db..00000000 --- a/astune/task_judge/judge_base.py +++ /dev/null @@ -1,11 +0,0 @@ -class JudgeBase(): - - def __init__(self, config): - self.config = config - - def compute_reward(self, judge_input_dictionary) -> tuple: - # judge_input_dictionary['env']: env_service 外部环境 (如果使用了env_service) - # judge_input_dictionary['task_core_arg']: 任务信息(如果里面包含了参考答案,可以从中取出) - # judge_input_dictionary['grouped_steps']: LLM的每一次历史对话记录(如果中间过程比较重要,可以从中取出) - - raise NotImplementedError \ No newline at end of file diff --git a/astune/task_judge/math_answer_as_judge.py b/astune/task_judge/math_answer_as_judge.py deleted file mode 100644 index 4d6ee9da..00000000 --- a/astune/task_judge/math_answer_as_judge.py +++ /dev/null @@ -1,63 +0,0 @@ -from astune.task_judge.judge_base import JudgeBase -import re - -class MathAnswerAsJudge(JudgeBase): - - def __init__(self, config): - self.config = config - - def compute_reward(self, judge_input_dictionary) -> tuple: - # judge_input_dictionary['env']: env_service 外部环境 (如果使用了env_service) - # judge_input_dictionary['task_core_arg']: 任务信息(如果里面包含了参考答案,可以从中取出) - # judge_input_dictionary['grouped_steps']: LLM的每一次历史对话记录(如果中间过程比较重要,可以从中取出) - - raw_reward = 0 - final_answer = judge_input_dictionary['final_answer'] # 默认没有final_answer,需要在workflow中手动调用 beyondagent_proxy.update_judge_input_dictionary(final_answer=final_answer) 注册 - task_core_arg = judge_input_dictionary['task_core_arg'] - reference_answer = task_core_arg.task.metadata['answer'] - reference_answer = reference_answer.split('####')[-1].strip() - - pattern = r'\\boxed\{([^}]*)\}' - match = re.search(pattern, final_answer) - if match: - result = match.group(1) - is_success = result == reference_answer - else: - is_success = False - - raw_reward = 1.0 if is_success else 0.0 - return raw_reward, is_success - - -class MathAnswerAndLlmAsJudge(JudgeBase): - - def __init__(self, config): - self.config = config - - def compute_reward(self, judge_input_dictionary) -> tuple: - raw_reward = 0 - - final_answer = judge_input_dictionary['final_answer'] - task_core_arg = judge_input_dictionary['task_core_arg'] - reference_answer = task_core_arg.task.metadata['answer'] - - from astune.context_manager.cmt_foreign_llm import construct_alien_llm_chat_fn - alien_llm_chat_fn = construct_alien_llm_chat_fn(self.config) - messages = [ - { - 'role':'system', - 'content':f'Is my result correct? If correct, say , otherwise say .' - }, - { - 'role':'user', - 'content':f'Is my result correct?\n\n\n----\nMy result: {final_answer}\n\n\n----\nReal result: {reference_answer}' - } - ] - res = alien_llm_chat_fn(messages=messages) - if '' in res['content']: - is_success = True - raw_reward = 1.0 - else: - is_success = False - raw_reward = 0.0 - return raw_reward, is_success \ No newline at end of file diff --git a/astune/task_reader/task_reader_base.py b/astune/task_reader/task_reader_base.py deleted file mode 100644 index b0ee6ad0..00000000 --- a/astune/task_reader/task_reader_base.py +++ /dev/null @@ -1,224 +0,0 @@ -import json -import uuid -import torch -import datasets -from typing import List, Dict, Optional -from astune.schema.task import Task -from astune.utils.process_dataset import create_rl_dataset, create_rl_sampler -from astune.env_service_client.env_client_ng import EnvClient - - -class TaskReaderBase: - def __init__(self, config): - self.config = config - - def get_training_tasks(self)->List[Task]: - raise NotImplementedError - - def get_validation_tasks(self)->List[Task]: - raise NotImplementedError - - -class TaskReaderAppWorld(TaskReaderBase): - def __init__(self, config): - super().__init__(config) - - def get_tasks(self, split): - env_url = self.config.astune.task_reader.env_service.env_url - env_type = self.config.astune.task_reader.env_service.env_type - env_service_client = EnvClient(base_url=env_url) - task_id_array = env_service_client.get_env_profile(env_type, split=split) - if len(task_id_array) == 0: - raise ValueError(f"No task_id found for env_type: {env_type}, split: {split}, Please check connection to {env_url}") - tasks = [ - Task( - main_query='[not defined]', - init_messages=[], - task_id=str(task_id), - env_type=env_type, - metadata={}, - ) for task_id in task_id_array] - return tasks - - def get_validation_tasks(self): - split = self.config.astune.task_reader.env_service.validation_split - return self.get_tasks(split=split) - - def get_training_tasks(self): - split = self.config.astune.task_reader.env_service.training_split - return self.get_tasks(split=split) - - -class TaskReaderJsonl(TaskReaderBase): - def __init__(self, config): - super().__init__(config) - - def _read_jsonl_file(self, file_path): - """ - Read tasks from a JSONL file. - - Args: - file_path (str): Path to the JSONL file. - - Returns: - List[Task]: List of Task objects. - """ - tasks = [] - try: - with open(file_path, 'r', encoding='utf-8') as f: - for line in f: - if line.strip(): # Skip empty lines - task_data = json.loads(line) - # Create a Task object from the JSON data - task = Task( - main_query=task_data.get('main_query', '[not defined]'), - init_messages=task_data.get('init_messages', []), - task_id=task_data.get('task_id', ''), - env_type=task_data.get('env_type', 'no_env'), - metadata=task_data.get('metadata', {}) - ) - tasks.append(task) - except FileNotFoundError: - raise ValueError(f"JSONL file not found: {file_path}") - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in file {file_path}: {str(e)}") - - if len(tasks) == 0: - raise ValueError(f"No tasks found in file: {file_path}") - - return tasks - - def get_training_tasks(self) -> List[Task]: - """ - Get training tasks from the JSONL file specified in the config. - - Returns: - List[Task]: List of training Task objects. - """ - file_path = self.config.astune.task_reader.dataset_file.training.file_path - return self._read_jsonl_file(file_path) - - def get_validation_tasks(self) -> List[Task]: - """ - Get validation tasks from the JSONL file specified in the config. - - Returns: - List[Task]: List of validation Task objects. - """ - file_path = self.config.astune.task_reader.dataset_file.validation.file_path - return self._read_jsonl_file(file_path) - - -class TaskReaderHuggingFace(TaskReaderBase): - """ - Task reader that reads tasks from Hugging Face datasets. - - This class allows loading tasks directly from Hugging Face dataset repositories. - It supports configuring the dataset name and split names for training and validation. - """ - - def __init__(self, config): - super().__init__(config) - - - def _load_dataset_split(self, dataset_name: str, split: str) -> List[Task]: - """ - Load a dataset split from Hugging Face datasets. - - Args: - dataset_name: Name of the dataset in Hugging Face format (e.g., 'gsm8k') - split: Name of the split to load (e.g., 'train', 'validation') - - Returns: - List[Task]: List of Task objects created from the dataset. - """ - try: - dataset = datasets.load_dataset(dataset_name, split=split) - except Exception as e: - raise ValueError(f"Failed to load dataset '{dataset_name}' with split '{split}': {str(e)}") - - # if len(dataset) == 0: - # raise ValueError(f"No examples found in dataset '{dataset_name}' with split '{split}'") - - tasks = [] - for idx, example in enumerate(dataset): - # Create Task object - task = Task( - main_query=example['question'], - init_messages=[], # Dataset examples typically don't have init messages - task_id=str(idx), - env_type=f"no_env", - metadata=example, - ) - tasks.append(task) - - return tasks - - def get_training_tasks(self) -> List[Task]: - """ - Get training tasks from the Hugging Face dataset specified in the config. - - Returns: - List[Task]: List of training Task objects. - """ - dataset_name = self.config.astune.task_reader.huggingface_dat_repo.dataset_path - split = self.config.astune.task_reader.huggingface_dat_repo.training_split - return self._load_dataset_split(dataset_name, split) - - def get_validation_tasks(self) -> List[Task]: - """ - Get validation tasks from the Hugging Face dataset specified in the config. - - Returns: - List[Task]: List of validation Task objects. - """ - dataset_name = self.config.astune.task_reader.huggingface_dat_repo.dataset_path - split = self.config.astune.task_reader.huggingface_dat_repo.validation_split - return self._load_dataset_split(dataset_name, split) - - -class TaskReaderRouter(TaskReaderBase): - def __init__(self, config): - super().__init__(config) - self.task_reader_type = self.config.astune.task_reader.type - if self.task_reader_type == 'env_service': - self.task_reader = TaskReaderAppWorld(config) - elif self.task_reader_type == 'dataset_file': - self.task_reader = TaskReaderJsonl(config) - elif self.task_reader_type == 'huggingface_dat_repo': - self.task_reader = TaskReaderHuggingFace(config) - else: - raise ValueError(f"Unsupported task reader type: {self.task_reader_type}") - - def get_training_tasks(self) -> List[Task]: - return self.task_reader.get_training_tasks() - - def get_validation_tasks(self) -> List[Task]: - return self.task_reader.get_validation_tasks() - -def task_to_standard_dataset(tasks: List[Task]) -> datasets.Dataset: - """ - Convert a list of Task objects to a standard Hugging Face Dataset. - - Args: - tasks (List[Task]): List of Task objects. - - Returns: - datasets.Dataset: Hugging Face Dataset containing the tasks. - """ - data = { - 'task_id': [], - 'main_query': [], - 'init_messages': [], - 'env_type': [], - 'metadata': [], - } - - for task in tasks: - data['task_id'].append(task.task_id) - data['main_query'].append(task.main_query) - data['init_messages'].append(task.init_messages) - data['env_type'].append(task.env_type) - data['metadata'].append(task.metadata) - - return datasets.Dataset.from_dict(data) \ No newline at end of file diff --git a/astune/utils/analysis_time_cost.py b/astune/utils/analysis_time_cost.py deleted file mode 100644 index 0db25fdc..00000000 --- a/astune/utils/analysis_time_cost.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -import sys - - -def time_cost_analysis(file_path: str): - result_time = {} - result_len = {} - all_time = [] - all_len = [] - with open(file_path) as f: - for line in f: - if not line: - continue - - if "info_dict=" not in line: - continue - - if "time_cost" not in line: - continue - - line = line.split("info_dict=")[1] - line_dict = json.loads(line) - act_step = line_dict["act_step"] - time_cost = line_dict["time_cost"] - char_len = len(line_dict["llm_output"]["content"]) - - if act_step not in result_time: - result_time[act_step] = [] - result_time[act_step].append(time_cost) - all_time.append(time_cost) - - if act_step not in result_len: - result_len[act_step] = [] - result_len[act_step].append(char_len) - all_len.append(char_len) - - for k, v in sorted(result_time.items(), key=lambda x: x[0]): - len_list = result_len[k] - print(f"act_step={k} time_cost={sum(v) / len(v):.2f} " - f"count={len(v)} " - f"len={sum(len_list) / len(len_list):.2f} " - f"efficient={sum(v) * 1000 / sum(len_list):.2f}") - - print(f"time_cost={sum(all_time) / len(all_time):.2f} " - f"count={len(all_time)} " - f"len={sum(all_len) / len(all_len):.2f} " - f"efficient={sum(all_time) * 1000 / sum(all_len):.2f}") - - -if __name__ == "__main__": - for file in sys.argv[1:]: - time_cost_analysis(file) diff --git a/astune/utils/async_http_client.py b/astune/utils/async_http_client.py deleted file mode 100644 index 3b1eaa01..00000000 --- a/astune/utils/async_http_client.py +++ /dev/null @@ -1,79 +0,0 @@ -import asyncio -import time -from typing import Any - -import aiohttp -from loguru import logger -from pydantic import BaseModel, Field, PrivateAttr, model_validator - - -class AsyncHttpClient(BaseModel): - url: str = Field(default="") - keep_alive: bool = Field(default=False, description="if true, use session to keep long connection") - timeout: int = Field(default=300, description="request timeout, second") - return_default_if_error: bool = Field(default=True) - - request_start_time: float = Field(default_factory=time.time) - request_time_cost: float = Field(default=0.0, description="request time cost") - - retry_sleep_time: float = Field(default=0.5, description="interval time for retry") - retry_time_multiplier: float = Field(default=2.0, description="retry time multiplier") - retry_max_count: int = Field(default=1, description="maximum number of retries") - - _client: Any | aiohttp.ClientSession = PrivateAttr() - - @model_validator(mode="after") - def init_client(self): - self._client = aiohttp.ClientSession(timeout=self.timeout) if self.keep_alive else aiohttp - return self - - async def __aenter__(self): - return self - - async def __aexit__(self, *args): - await self.close() - self.request_time_cost: float = time.time() - self.request_start_time - - async def close(self): - if isinstance(self._client, aiohttp.ClientSession): - await self._client.close() - - def parse_result(self, response: aiohttp.ClientResponse | None = None, **kwargs): - return response.json() - - def return_default(self, **kwargs) -> Any: - return None - - async def request( - self, - data: str | Any = None, - json_data: dict = None, - headers: dict = None, - http_enum: str = "post", - **kwargs, - ) -> Any: - retry_sleep_time = self.retry_sleep_time - method = http_enum - - for i in range(self.retry_max_count): - try: - response = await self._client.request(method=method, url=self.url, data=data, json=json_data, - headers=headers) - - result = self.parse_result(response=response, data=data, json_data=json_data, headers=headers, - http_enum=http_enum, **kwargs) - return result - - except Exception as e: - logger.exception(f"{self.__class__.__name__} {i}th request failed with args={e.args}") - - if i == self.retry_max_count - 1: - if self.return_default_if_error: - return self.return_default() - else: - raise e - - retry_sleep_time *= self.retry_time_multiplier - await asyncio.sleep(retry_sleep_time) - - return None diff --git a/astune/utils/cleaner.py b/astune/utils/cleaner.py deleted file mode 100644 index 815f0977..00000000 --- a/astune/utils/cleaner.py +++ /dev/null @@ -1,61 +0,0 @@ - -import subprocess -import argparse -import shutil -import time -import sys -import os -import shlex - -def _fast_kill_by_keyword_bash(keyword: str, exclude_substrings=["vscode"], grace_seconds: float = 1.0): - """Use bash pipelines to kill processes matching keyword quickly. - - - Filters out processes containing any exclude_substrings - - Excludes current launcher process - - Sends TERM once to all PIDs, then KILL once to all PIDs after a short grace period - - Returns list of PIDs targeted - """ - self_pid = os.getpid() - - # Build a fast PID collector using pgrep if available; fallback to ps/grep - # We prefer pgrep -af to filter by full command and then extract PID (column 1) - exclude_filters = " ".join([f"| grep -v -F {shlex.quote(s)}" for s in exclude_substrings]) - pid_list_cmd = ( - f"(pgrep -af -- {shlex.quote(keyword)} 2>/dev/null || true) " - f"{exclude_filters} | awk '{{print $1}}' | grep -v -x {self_pid} || true" - ) - - try: - res = subprocess.run(["bash", "-lc", pid_list_cmd], capture_output=True, text=True, check=False) - pids = [pid for pid in res.stdout.split() if pid.isdigit()] - except Exception as e: - print(f"Failed to list PIDs via bash: {e}") - pids = [] - - # Fallback to ps/grep if pgrep path produced nothing (e.g., no pgrep installed) - if not pids: - ps_pid_cmd = ( - f"ps -eo pid,command -ww | grep -F -- {shlex.quote(keyword)} | grep -v grep " - f"{exclude_filters} | awk '{{print $1}}' | grep -v -x {self_pid} || true" - ) - try: - res2 = subprocess.run(["bash", "-lc", ps_pid_cmd], capture_output=True, text=True, check=False) - pids = [pid for pid in res2.stdout.split() if pid.isdigit()] - except Exception as e: - print(f"Failed to list PIDs via ps/grep: {e}") - pids = [] - - if not pids: - return [] - - pid_args = " ".join(pids) - try: - # Send TERM to all in one call - subprocess.run(["bash", "-lc", f"kill -TERM -- {pid_args} 2>/dev/null || true"], check=False) - time.sleep(grace_seconds) - # Escalate with KILL once; ignore failures for already-exited PIDs - subprocess.run(["bash", "-lc", f"kill -KILL -- {pid_args} 2>/dev/null || true"], check=False) - except Exception as e: - print(f"Error issuing kill commands: {e}") - - return [int(p) for p in pids] \ No newline at end of file diff --git a/astune/utils/http_client.py b/astune/utils/http_client.py deleted file mode 100644 index 53639f58..00000000 --- a/astune/utils/http_client.py +++ /dev/null @@ -1,151 +0,0 @@ -import http -import time -from typing import Any - -import requests -from loguru import logger -from pydantic import BaseModel, Field, PrivateAttr, model_validator - - -class HttpClient(BaseModel): - url: str = Field(default="") - keep_alive: bool = Field(default=False, description="if true, use session to keep long connection") - timeout: int = Field(default=300, description="request timeout, second") - return_default_if_error: bool = Field(default=True) - - request_start_time: float = Field(default_factory=time.time) - request_time_cost: float = Field(default=0.0, description="request time cost") - - retry_sleep_time: float = Field(default=0.5, description="interval time for retry") - retry_time_multiplier: float = Field(default=2.0, description="retry time multiplier") - retry_max_count: int = Field(default=1, description="maximum number of retries") - - _client: Any = PrivateAttr() - - @model_validator(mode="after") - def init_client(self): - self._client = requests.Session() if self.keep_alive else requests - return self - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - self.request_time_cost: float = time.time() - self.request_start_time - - def close(self): - if isinstance(self._client, requests.Session): - self._client.close() - - def _request(self, - data: str = None, - json_data: dict = None, - headers: dict = None, - stream: bool = False, - http_enum: str = "post"): - - if http_enum == "post": - response: requests.Response = self._client.post(url=self.url, - data=data, - json=json_data, - headers=headers, - stream=stream, - timeout=self.timeout) - - elif http_enum == "get": - response: requests.Response = self._client.get(url=self.url, - data=data, - json=json_data, - headers=headers, - stream=stream, - timeout=self.timeout) - - else: - raise NotImplementedError - - if response.status_code != http.HTTPStatus.OK: - raise RuntimeError(f"request failed! content={response.json()}") - - return response - - def parse_result(self, response: requests.Response | Any = None, **kwargs): - return response.json() - - def return_default(self, **kwargs): - return None - - def request(self, - data: str | Any = None, - json_data: dict = None, - headers: dict = None, - http_enum: str = "post", - **kwargs): - - retry_sleep_time = self.retry_sleep_time - for i in range(self.retry_max_count): - try: - response = self._request(data=data, json_data=json_data, headers=headers, http_enum=http_enum) - result = self.parse_result(response=response, - data=data, - json_data=json_data, - headers=headers, - http_enum=http_enum, - **kwargs) - return result - - except Exception as e: - logger.exception(f"{self.__class__.__name__} {i}th request failed with args={e.args}") - - if i == self.retry_max_count - 1: - if self.return_default_if_error: - return self.return_default() - else: - raise e - - retry_sleep_time *= self.retry_time_multiplier - time.sleep(retry_sleep_time) - - return None - - def request_stream(self, - data: str = None, - json_data: dict = None, - headers: dict = None, - http_enum: str = "post", - **kwargs): - - retry_sleep_time = self.retry_sleep_time - for i in range(self.retry_max_count): - try: - response = self._request(data=data, - json_data=json_data, - headers=headers, - stream=True, - http_enum=http_enum) - request_context = {} - for iter_idx, line in enumerate(response.iter_lines()): - yield self.parse_result(line=line, - request_context=request_context, - index=iter_idx, - data=data, - json_data=json_data, - headers=headers, - http_enum=http_enum, - **kwargs) - - return None - - except Exception as e: - logger.exception(f"{self.__class__.__name__} {i}th request failed with args={e.args}") - - if i == self.retry_max_count - 1: - if self.return_default_if_error: - return self.return_default() - else: - raise e - - retry_sleep_time *= self.retry_time_multiplier - time.sleep(retry_sleep_time) - - return None diff --git a/astune/utils/markdown_parser.py b/astune/utils/markdown_parser.py deleted file mode 100644 index 2a2667d2..00000000 --- a/astune/utils/markdown_parser.py +++ /dev/null @@ -1,38 +0,0 @@ -def read_markdown_and_extract_sections( - markdown_text, - expected_sections=["current step", "previous instruction code", "relevant environment feedback", "next-step instruction code"], - default_placeholder="❌ not available." -): - sections = {} - # if not isinstance(markdown_text, str): - # markdown_text = markdown_text.content_for_future - lines = markdown_text.splitlines() - current_section = None - - for line in lines: - line = line.strip() - if not line: - continue - if line.startswith("# "): - section_name = line[2:].strip().lower() - current_section = section_name - sections[current_section] = [] - elif current_section: - sections[current_section].append(line) - - for key in list(sections.keys()): - if key not in expected_sections: - sections.pop(key, None) - - section_to_return = {k: "\n".join(v) for k, v in sections.items()} - find_all_expected_sections = True - find_no_expected_sections = True - for section in expected_sections: - if section not in section_to_return: - section_to_return[section] = default_placeholder - find_all_expected_sections = False - else: - find_no_expected_sections = False - - - return section_to_return, find_all_expected_sections, find_no_expected_sections diff --git a/astune/utils/message.py b/astune/utils/message.py deleted file mode 100644 index a80ad546..00000000 --- a/astune/utils/message.py +++ /dev/null @@ -1,12 +0,0 @@ -import requests, os - -def send_train_message(message: str): - # 发送短信汇报训练进程 - assert len(message) < 64, f"Message too long: {(message)}" - if os.getenv("ALIYUN_SMS_SERVICE"): - try: requests.post( - json={"phone_numbers": "18810508767", "server_code": "DLC", "error": message, "error_level": "无"}, - url=os.getenv("ALIYUN_SMS_SERVICE", "http://localhost:8000/send-sms"), - headers={"Content-Type": "application/json"} - ) - except Exception as e: print(f"Failed to send sms: {e}") \ No newline at end of file diff --git a/astune/utils/model_merger.py b/astune/utils/model_merger.py deleted file mode 100644 index ec3983c1..00000000 --- a/astune/utils/model_merger.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. - -To merge FSDP checkpoints: -```sh -python scripts/model_merger.py merge \ - --backend fsdp \ - --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -To merge Megatron checkpoints: -```sh -python scripts/model_merger.py merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -For more details, please refer to documentation: -https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model -""" - -import argparse -import os -import re -import warnings -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional, Union - -import numpy as np -import torch -from accelerate import init_empty_weights -from safetensors.torch import load_file -from torch.distributed._tensor import Placement, Shard -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForTokenClassification, - AutoModelForVision2Seq, - GenerationConfig, - PretrainedConfig, -) - -try: - # for torch 2.5+ - from torch.distributed.tensor import DTensor -except ImportError: - from torch.distributed._tensor import DTensor - -from tqdm import tqdm - -from verl.utils import hf_processor, hf_tokenizer - - -@dataclass -class ModelMergerConfig: - operation: str # 'merge' or 'test' - backend: str - local_dir: str - hf_model_config_path: str - target_dir: Optional[str] = "tmp" - hf_upload_path: Optional[str] = None - private: bool = False - test_hf_dir: Optional[str] = None - tie_word_embedding: bool = False - is_value_model: bool = False - hf_model_path: Optional[str] = None - hf_upload: bool = field(init=False) - - def __post_init__(self): - self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) - if self.operation == "test": - self.target_dir = None - self.hf_upload_path = None - self.private = False - - -class BaseModelMerger(ABC): - def __init__(self, config: ModelMergerConfig): - self.config = config - self.hf_model_config_path = config.hf_model_config_path - - if config.hf_model_path: - print("Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. ") - self.hf_model_config_path = config.hf_model_path - - self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) - - def get_transformers_auto_model_class(self): - if "ForTokenClassification" in self.model_config.architectures[0]: - return AutoModelForTokenClassification - elif "ForCausalLM" in self.model_config.architectures[0]: - return AutoModelForCausalLM - elif "ForConditionalGeneration" in self.model_config.architectures[0]: - return AutoModelForVision2Seq - - raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") - - def patch_model_generation_config(self, model): - """ - The generation_config created from model config may be different to the pretrained model, - this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 - - This function patch the generation_config created from model config to the pretrained model. - """ - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) - except OSError: - print(f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config.") - return model - - def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): - """ - Save lora adapter to safetensors. - - Returns: - lora_path: str, the path to the lora adapter. None if no lora adapter found. - - Note: - This function change the 'state_dict' in place. - """ - lora_params_names = [name for name in state_dict.keys() if "lora_" in name] - - if len(lora_params_names) == 0: - return None - - import json - from typing import OrderedDict - - import peft - from safetensors.torch import save_file - - lora_params = OrderedDict() - target_modules = set() - lora_key = None - - for name in lora_params_names: - lora_key = name.replace(".default.weight", ".weight") - target_modules.add(lora_key.split(".")[-3]) - lora_params[lora_key] = state_dict.pop(name) - - lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) - peft_dict = { - "r": lora_rank, - "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. - "target_modules": list(target_modules), - } - peft_config = peft.LoraConfig(**peft_dict).to_dict() - peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None - peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None - peft_config["target_modules"] = list(peft_config["target_modules"]) - - lora_path = os.path.join(self.config.target_dir, "lora_adapter") - os.makedirs(lora_path, exist_ok=True) - with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: - json.dump(peft_config, f, ensure_ascii=False, indent=4) - save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) - - for name in list(state_dict.keys()): - key = name.replace("base_model.model.", "").replace(".base_layer.weight", ".weight").replace(".base_layer.bias", ".bias") - state_dict[key] = state_dict.pop(name) - - return lora_path - - def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - with init_empty_weights(): - model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) - model.to_empty(device="cpu") - model = self.patch_model_generation_config(model) - - lora_path = self.save_lora_adapter(state_dict) - if lora_path: - print(f"Saving lora adapter to {lora_path}") - - print(f"Saving model to {self.config.target_dir}") - model.save_pretrained(self.config.target_dir, state_dict=state_dict) - del state_dict - del model - - processor = hf_processor(self.hf_model_config_path) - tokenizer = hf_tokenizer(self.hf_model_config_path) - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def upload_to_huggingface(self): - from huggingface_hub import HfApi - - api = HfApi() - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") - - @abstractmethod - def merge_and_save(self): - raise NotImplementedError("Subclasses should implement this method") - - -class FSDPModelMerger(BaseModelMerger): - def _get_world_size(self) -> int: - """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').""" - for filename in os.listdir(self.config.local_dir): - match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) - if match: - return int(match.group(1)) - raise FileNotFoundError(f"Could not determine world size. No file matching 'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}") - - def _load_rank_zero_state_dict(self, world_size: int) -> dict: - return torch.load(Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", map_location="cpu", weights_only=False) - - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: - """ - Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. - If no DTensor is found, infers a simple FSDP mesh based on world_size. - """ - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - - if isinstance(weight, DTensor): - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names - else: - # for non-DTensor - mesh = np.array([world_size], dtype=np.int64) - mesh_dim_names = ("fsdp",) - - return mesh, mesh_dim_names - - def _calculate_shard_configuration(self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]) -> tuple[int, tuple[int, ...]]: - """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" - - if "tp" in mesh_dim_names: - # TODO: "tp" is not supported yet due to the above assert - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) - - return total_shards, mesh_shape - - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: - """Merges a list of tensors based on their DTensor placement""" - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - - raise NotImplementedError(f"Unsupported placement: {placement}") - - def _load_and_merge_state_dicts(self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]) -> dict[str, torch.Tensor]: - model_state_dict_lst = [None] * total_shards - - def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict - return state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): - future.result() - - # Merge state dicts from all shards - state_dict = {} - param_placements: dict[str, list] = {} - - for key in set(model_state_dict_lst[0].keys()): - state_dict[key] = [] - for model_state_shard in model_state_dict_lst: - # add tensor shard in order of rank to state_dict[key] - tensor = model_state_shard.pop(key) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] in ("dp", "ddp"): - placements = placements[1:] - - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key].append(tensor.bfloat16()) - - del model_state_dict_lst - - # Merge tensors - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - print(f"No need to merge key {key}") - continue - if key in param_placements: - # merge shards - placements: tuple[Shard] = param_placements[key] - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = self._merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") - else: - state_dict[key] = torch.cat(state_dict[key], dim=0) - - return state_dict - - def merge_and_save(self): - world_size = self._get_world_size() - rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) - print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) - print(f"Processing model shards with {total_shards} {mesh_shape} in total") - - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) - hf_state_dict = hf_model.state_dict() - del hf_model - - hf_model_keys = set(hf_state_dict.keys()) - collected_keys = set(state_dict.keys()) - - missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - - extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - - for key in hf_model_keys: - hf_shape = hf_state_dict[key].shape - collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - - hf_dtype = hf_state_dict[key].dtype - collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) - - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") - - -class MegatronModelMerger(BaseModelMerger): - def __init__(self, config: ModelMergerConfig): - from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path - - config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) - super().__init__(config) - - self.params_mapping = { - # megatron core gpt model name, huggingface model name - # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the longer key within the containing relationship is processed first. - "embedding.word_embeddings": "model.embed_tokens", - # attn - "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", - "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", - "self_attention.linear_qkv": "self_attn.qkv_proj", - "self_attention.q_layernorm": "self_attn.q_norm", - "self_attention.k_layernorm": "self_attn.k_norm", - "self_attention.linear_proj": "self_attn.o_proj", - # mla - "self_attention.linear_q_proj": "self_attn.q_proj", - "self_attention.linear_q_down_proj": "self_attn.q_a_proj", - "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", - "self_attention.linear_q_up_proj": "self_attn.q_b_proj", - "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", - "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", - "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", - # mlp - "pre_mlp_layernorm": "post_attention_layernorm", - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", - "mlp.linear_fc1": "mlp.gate_up_proj", - "mlp.linear_fc2": "mlp.down_proj", - # moe - "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", - "mlp.router": "mlp.gate", - "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", - "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", - "linear_fc1": "gate_up_proj", - "linear_fc2": "down_proj", - # output - "final_layernorm": "norm", - "output_layer": "lm_head", - } - - def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: - tp_rank = pp_rank = None - rank_list = sharded_dir.split("_")[2:] - if re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir): - tp_rank = int(rank_list[0]) - pp_rank = int(rank_list[1]) - elif re.match(r"mp_rank_(\d\d)", sharded_dir): - tp_rank = int(rank_list[0]) - pp_rank = 0 - - assert tp_rank is not None and pp_rank is not None, f"Invalid sharded dir {sharded_dir}" - - return tp_rank, pp_rank - - def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: - """ - Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). - Determines TP and PP sizes from directory names. - """ - tp_size = 0 - pp_size = 0 - sharded_dirs = sorted(os.listdir(model_path)) - for sharded_dir in sharded_dirs: - assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - tp_size = max(tp_size, tp_rank + 1) - pp_size = max(pp_size, pp_rank + 1) - return sharded_dirs, tp_size, pp_size - - def _merge_across_tp(self, key: str, tp_data: list[torch.Tensor], config: PretrainedConfig, tp_size: int, is_value_model: bool = False) -> Union[torch.Tensor, list[torch.Tensor]]: - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in tp_data: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - return [gate, up] - elif "self_attention.linear_qkv." in key and "layer_norm" not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst = [] - k_lst = [] - v_lst = [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 - kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - - for infer_param in tp_data: - num_query_groups_per_partition = config.num_key_value_heads // tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - return [q, k, v] - elif "layer_norm" in key or "layernorm" in key or "router" in key or ("output_layer" in key and is_value_model): - return tp_data[0] - else: - dim = 0 - if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: - dim = 1 - return torch.cat(tp_data, dim=dim) - - def _load_state_dicts(self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int) -> list[list[dict]]: - model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)] - - def _process_one_megatron_shard(sharded_dir: str): - model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" - state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - model_state_dict_lst[pp_rank][tp_rank] = state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] - for future in tqdm(futures, desc=f"Loading {len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): - future.result() - - return model_state_dict_lst - - def _check_megatron_state_key(self, key: str) -> bool: - """ - Checks if the key is a valid Megatron state key. - - Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. - Shall not use key starts with "model." - """ - if key.startswith("model."): - raise ValueError(f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer' in TransformerLayer.") - - skip_checking_keys = ["embedding.word_embeddings", "output_layer"] - for skip_key in skip_checking_keys: - if skip_key in key: - print(f"skip checking key {key}") - return - - # Exclude extra state keys - if not key.startswith("decoder"): - raise ValueError(f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer.") - - def _merge_state_dicts(self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int) -> dict[str, torch.Tensor]: - state_dict = {} - vpp_size = len(model_state_dict_lst[0][0]) - layers_cum = 0 - - for vpp_rank in range(vpp_size): - for pp_rank in range(pp_size): - layers_handled = 0 - keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() - for key in keys: - if "extra_state" in key: - continue - if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - - self._check_megatron_state_key(key) - hf_name = self._replace_name(key, self.params_mapping) - assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." - if "model.layers." in hf_name: - local_layer_no = int(hf_name.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = hf_name.split(".") - new_key_list[2] = str(global_layer_no) - hf_name = ".".join(new_key_list) - else: - warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = self._merge_across_tp(key, tp_data, self.model_config, tp_size, self.config.is_value_model) - - if not isinstance(merged, list): - state_dict[hf_name] = merged - elif len(merged) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], merged): - state_dict[hf_name.replace("qkv", n)] = d - elif len(merged) == 2: - # split gate up - state_dict[hf_name.replace("gate_up", "gate")] = merged[0] - state_dict[hf_name.replace("gate_up", "up")] = merged[1] - print(f"converted {key} to {hf_name} with shape {merged.shape if isinstance(merged, torch.Tensor) else [t.shape for t in merged]}") - - layers_cum += layers_handled + 1 # zero based - - return state_dict - - def merge_and_save(self): - from verl.utils.megatron_utils import get_model_checkpoint_path - - model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) - sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) - print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") - - model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) - merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) - del model_state_dict_lst - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - """ - Compares the merged Megatron state_dict against a reference safetensors model. - Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. - """ - ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") - - for name, loaded_weight in state_dict.items(): - # name = self._replace_name(original_name, self.params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embedding and "lm_head.weight" in name: - continue - if name not in ref_state_dict: - raise RuntimeError(f"key: {name} not exist in state_dict") - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2) - - def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: - for m_name, v_name in name_mapping.items(): - if m_name not in megatron_name: - continue - - megatron_name = megatron_name.replace("decoder", "model") - param_name = megatron_name.replace(m_name, v_name) - return param_name - - return None # Return None if no mapping found - - -def main(): - parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") - - base_op_parser = argparse.ArgumentParser(add_help=False) - base_op_parser.add_argument("--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model") - base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") - base_op_parser.add_argument("--hf_model_path", type=str, default=None, help="(Deprecated) Path to the original Hugging Face model for config.") - base_op_parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights (currently only Megatron supported)") - base_op_parser.add_argument("--is-value-model", action="store_true", help="Whether the model is a value model (currently only Megatron supported)") - - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") - merge_parser.add_argument("--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model") - merge_parser.add_argument("--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model") - merge_parser.add_argument("--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository") - - test_parser = subparsers.add_parser("test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model") - test_parser.add_argument("--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing") - - args = parser.parse_args() - - common_config_args = { - "operation": args.operation, - "backend": args.backend, - "tie_word_embedding": args.tie_word_embedding, - "is_value_model": args.is_value_model, - "local_dir": args.local_dir, - "hf_model_path": args.hf_model_path, - "hf_model_config_path": args.local_dir, - } - - if args.operation == "merge": - config = ModelMergerConfig( - **common_config_args, - target_dir=args.target_dir, - hf_upload_path=args.hf_upload_path, - private=args.private, - test_hf_dir=None, - ) - os.makedirs(config.target_dir, exist_ok=True) - elif args.operation == "test": - config = ModelMergerConfig( - **common_config_args, - test_hf_dir=args.test_hf_dir, - # the following args are not used by test operation - target_dir=None, - hf_upload_path=None, - private=False, - ) - else: - raise NotImplementedError(f"Unknown operation: {args.operation}") - - if config.backend == "fsdp": - merger = FSDPModelMerger(config) - elif config.backend == "megatron": - merger = MegatronModelMerger(config) - else: - raise NotImplementedError(f"Unknown backend: {config.backend}") - - merger.merge_and_save() - - -if __name__ == "__main__": - main() diff --git a/astune/utils/n_gram.py b/astune/utils/n_gram.py deleted file mode 100644 index 6567d8ef..00000000 --- a/astune/utils/n_gram.py +++ /dev/null @@ -1,84 +0,0 @@ -import re - -def preserve_chinese(text): - # 使用正则表达式匹配所有中文字符 - chinese_chars = re.findall(r'[\u4e00-\u9fff]', text) - # 将匹配到的中文字符拼接成一个字符串 - return ''.join(chinese_chars) - - -def get_repetition_penalty_reward(ngram_size: int, max_penalty: float): - """ - Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373. - Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py - - Args: - ngram_size: size of the n-grams - max_penalty: Maximum (negative) penalty for wrong answers - """ - if max_penalty > 0: - raise ValueError(f"max_penalty {max_penalty} should not be positive") - - - def zipngram_chinese(text: str, ngram_size: int): - import jieba - text = preserve_chinese(text) - seg_list = list(jieba.cut(text)) - # print(seg_list) - return zip(*[seg_list[i:] for i in range(ngram_size)]) - - - def repetition_penalty_reward(completions, **kwargs) -> float: - """ - reward function the penalizes repetitions - ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py - - Args: - completions: List of model completions - """ - - contents = [completion[0]["content"] for completion in completions] - rewards = [] - for completion in contents: - if completion == "": - rewards.append(0.0) - continue - - ngrams = set() - total = 0 - for ng in zipngram_chinese(completion, ngram_size): - ngrams.add(ng) - total += 1 - - if total == 0: - rewards.append(0.0) - continue - - # total 是定值(句子分词长度) - # 假设 全部不重复 total = len(ngrams),scaling ~= 0, reward=0 - # 假设 全部重复 len(ngrams) = 1, scaling ~= 1, reward=-1 - scaling = 1 - len(ngrams) / total - reward = scaling * max_penalty - rewards.append(reward) - return rewards - - return repetition_penalty_reward - - -if __name__ == "__main__": - for ngram_size in range(0, 100): - reward_fn = get_repetition_penalty_reward(ngram_size=ngram_size, max_penalty=-1.0) - completions = [ - [ - { - "content": "<|thought_start|>好的,我现在需要处理用户的问题:“那CPO板块未来会涨还是跌呢?”。首先,我需要回顾之前的对话历史,确保回答连贯。用户之前询问了CPO板块的当天表现,我提供了截止到2025年3月10日的收盘数据,并提到台积电计划量产CPO产品的积极信号。然后,用户问到担忧的具体内容,我解释了国际大厂发布CPO交换机可能影响利润的担忧,但也提到中际旭创等公司在积极应对,台积电也在推进技术。\n\n现在用户直接问CPO板块的未来走势。根据平台提供的供给知识,CPO板块受益于光通信技术和AI发展的双重驱动,台积电与英伟达、博通合作量产CPO产品,Marvell也在AI加速器上取得突破。这些都显示行业有技术进步和市场需求增长的潜力。同时,行业市盈率低于近10年80%的时间,估值偏低,具备性价比。虽然主力资金近期净流出,但重要股东持续增持,市场讨论热度提升,情绪趋于积极。\n\n用户的风险等级是高风险,理财认知高,但CPO属于科技赛道,用户当前科技类可能已经超配。需要结合这些信息,提供一个既符合用户特征,又基于现有数据的回复。需要强调行业潜力和当前估值优势,同时提示市场波动性,建议根据个人风险承受能力决策。确保不能出现重复引导,比如再次推荐基金或询问兴趣,因为这已经在前两轮对话中提到过。需要保持专业性和准确性,不编造信息,只基于提供的资料。确保不能出现重复引导,比如再次推荐基金或询问兴趣,因为这已经在前两轮对话中提到过。需要保持专业性和准确性,不编造信息,只基于提供的资料。确保不能出现重复引导,比如再次推荐基金或询问兴趣,因为这已经在前两轮对话中提到过。需要保持专业性和准确性,不编造信息,只基于提供的资料。<|thought_end|>您关注的CPO板块确实展现出了一定的投资潜力。从当前市场情况来看,CPO行业受益于光通信技术和AI发展的双重驱动,展现出良好的投资潜力。近期,台积电宣布与英伟达、博通合作,预计下半年量产CPO产品,反映行业技术进步和市场需求的增长。同时,行业市盈率低于近10年80%的时间,估值偏低,具备较高的性价比。\\n不过,作为科技赛道,CPO板块的波动性相对较高,短期内市场情绪和资金流向可能会影响板块表现。如果您是短期投资者,可能会经历较大的价格波动;如果是长期投资者,可以考虑分批布局,捕捉行业长期增长潜力。\\n您是想更倾向于短期把握CPO板块的波动机会,还是希望通过长期持有来捕捉行业成长?我可以根据您的偏好,为您推荐几只相关基金。", - }, - ], - [ - { - "content": "今天,一则消息在科技圈热传——据外媒报道,在近日向白宫提交的一项政策建议中,OpenAI对DeepSeek进行了十分露骨的攻击。一边称DeepSeek是“另一个华为”,是“国家补贴”“国家控制”的,会受中国指示“操纵其模型以造成损害”,一边强烈呼吁对该机构及类似机构开发的中国AI模型实施禁令。OpenAI的提案并非临时起意。今年1月,特朗普撤销了此前美国AI行政命令《安全、可靠和值得信赖的人工智能开发与使用》,随后签署了一项新的行政命令,宣布“美国的政策是维持并提升全球人工智能领域的主导地位”,并要求在180天内提交一份AI行动计划。无独有偶,美国另一家AI巨头Anthropic也强烈建议政府加强AI领域出口管制。它们显然都是瞄准了这一计划,想要以“技术话语权”换取“政策制定话语权”。事实上,以妖魔化外国竞争对手说事,博得政府支持最大化,是美国科技公司的惯用伎俩。但如此情绪激烈、赤裸裸地迎合美国“零和博弈”的政治逻辑,高调呼吁“抵制”“封禁”某家外国公司,多少还是过于难看。更可笑的是,这些话还出自自诩“开放”的OpenAI之口。OpenAI曾是开源的代名词,GPT-2的全面开放一度被视为行业标杆。但自GPT-3起,这家公司就选择筑起闭源高墙,更通过API订阅制赚得盆满钵满。本以为,DeepSeek掀起的AI行业普惠与效率革命,会让美国AI公司反思自己的路径选择。但如今来看,科技霸权主义对美国政治、美国科技企业的毒害远比我们想象的大得多。“层层封锁,封不住孤勇者。”不管出于什么心态,美国AI公司或许是忘了DeepSeek是怎么“逆袭”的了。近些年,美国政府为了压制中国人工智能的发展,动作不断,先是禁止AI芯片霸主英伟达向中国出口高端AI芯片,后又推动AI大模型霸主OpenAI停止向中国开放。这种情况下,即便对中国AI发展最为乐观的人也不敢想象——硬件性能无法做到最优,那就卷架构、卷工程,最终竟能“柳暗花明又一村”,让美国精心堆砌的人工智能高墙濒临坍塌。OpenAI将DeepSeek比作华为,怕是也忘了,几乎体验过美国所有打压、制裁手段的华为,如今已然是“轻舟已过万重山”。就如华为创始人任正非所说的,无论美国再怎么打压,华为也不恨美国。华为原来就像羊,慢慢吃草,变得很胖;美国是狼,狼一追羊就跑,减肥了就有战斗力。从更大范围看,于封锁中坚定自主创新、国产自研的心气与斗志,一直都是中国科技发展独特的“孤勇叙事”。至于迟迟丢不掉霸权逻辑者,在封闭、狭隘的创新环境里打转转,所谓的技术领先优势又能保持多久呢?最为讽刺的是,在提案中OpenAI还在唱“推广民主AI”“确保AI发展惠及所有人”的高调。试问,打压他国AI技术发展,到底“民主”在哪?“普惠”在哪?更何况,DeepSeek坚定走开源之路,极大促进了全球AI行业的合作交流,正让“AI人人可用”的理想不再遥不可及。“闭源”打压“开源”,还好意思抢占道德高地,不得不说,美国科技精英玩起政治双标那套来,也很溜。管中窥豹,可见一斑。马斯克曾言辞激烈地抨击OpenAI:“封闭”且“贪婪”。而这,形容的何尝不是当下的美国呢?" - } - ] - ] - rewards = reward_fn(completions) - print(ngram_size, rewards) diff --git a/astune/utils/process_dataset.py b/astune/utils/process_dataset.py deleted file mode 100644 index 721d8cc6..00000000 --- a/astune/utils/process_dataset.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import time -import datasets -from datasets import Dataset -from torch.utils.data import Dataset as TorchDataset -from torch.utils.data import RandomSampler, SequentialSampler - -from verl.utils.dataset.rl_dataset import RLHFDataset -from verl.utils.import_utils import load_extern_type -from verl.experimental.dataset.sampler import AbstractSampler - - -from typing import List, Optional, Union, Dict, Any -from transformers.processing_utils import ProcessorMixin -from omegaconf import DictConfig - - -class EnvServiceDataset(RLHFDataset): - """Dataset class that handles environment service data loading and processing.""" - - def __init__(self, - data_files: List[str], - tokenizer, - processor: Optional[ProcessorMixin], - config: DictConfig, - env_config: Optional[DictConfig] = None): - """Initialize the EnvServiceDataset. - - Args: - data_files: List of data file paths - tokenizer: The tokenizer to use - processor: The processor to use for multimodal data - config: Configuration for dataset - env_config: Configuration for environment service - """ - self.config = config - self.env_config = env_config or {} - super().__init__(data_files, tokenizer, config, processor) - - def _read_files_and_tokenize(self): - env_url = self.env_config.env_url - env_type = self.env_config.env_type - dataframes = [] - - from astune.env_service_client.env_client_ng import EnvClient - for parquet_file in self.data_files: - # read parquet files and cache - if 'train(read_from_env_service)' in parquet_file: - split = 'train' - elif 'val(read_from_env_service)' in parquet_file: - split = 'test' - split = 'dev' # or test_normal - else: - raise ValueError(f"Unsupported split: {parquet_file}") - env_service_client = EnvClient(base_url=env_url) - task_id_array = env_service_client.get_env_profile(env_type, split=split) - if len(task_id_array) == 0: - raise ValueError(f"No task_id found for env_type: {env_type}, split: {split}, Please check connection to {env_url}") - data = { - 'data_source': [env_type for task_id in task_id_array], - 'prompt': ['not available' for task_id in task_id_array], - 'reward_model': [{} for task_id in task_id_array], - 'extras': [{'task_id': task_id} for task_id in task_id_array], - } - dataframe = Dataset.from_dict(data) - dataframes.append(dataframe) - - self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) - print(f"dataset len: {len(self.dataframe)}") - self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe) - - -def create_rl_dataset( - data_paths: List[str], - data_config: DictConfig, - tokenizer, - processor: Optional[ProcessorMixin], - is_train: bool = True, - env_config: Optional[DictConfig] = None -) -> TorchDataset: - """Create a dataset. - - Arguments: - data_paths: List of paths to data files. - data_config: The data config. - tokenizer (Tokenizer): The tokenizer. - processor (Processor): The processor. - is_train (bool): Whether this is for training or validation. - env_config: Environment configuration. - - Returns: - dataset (Dataset): The dataset. - """ - - # Check if a custom dataset class is specified in the data configuration - # and if the path to the custom class is provided - if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: - # Dynamically load the custom dataset class - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) - # Verify that the custom dataset class inherits from torch.utils.data.Dataset - if not issubclass(dataset_cls, TorchDataset): - raise TypeError( - f"The custom dataset class '{data_config.custom_cls.name}' from " - f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" - ) - - elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: - # If a data generation strategy is specified, use the DynamicGenDataset class - from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset - dataset_cls = DynamicGenDataset - print("Using DynamicGenDataset for data generation.") - - else: - # Use EnvServiceDataset - dataset_cls = EnvServiceDataset - print(f"Using dataset class: {dataset_cls.__name__}") - - # Instantiate the dataset using the determined dataset class - print('using', dataset_cls) - dataset = dataset_cls( - data_files=data_paths, - tokenizer=tokenizer, - processor=processor, - config=data_config, - env_config=env_config, - ) - - return dataset - - -def create_rl_sampler( - data_config: DictConfig, - dataset: TorchDataset -) -> Union[RandomSampler, SequentialSampler, AbstractSampler]: - """Create a sampler for the dataset. - - Arguments: - data_config: The data config. - dataset (Dataset): The dataset. - - Returns: - sampler (Sampler): The sampler. - """ - if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: - curriculum_class = load_extern_type( - data_config.sampler.class_path, - data_config.sampler.class_name, - ) - sampler = curriculum_class( - data_source=dataset, - data_config=data_config, - ) - assert isinstance(sampler, AbstractSampler) - assert data_config.get("dataloader_num_workers", 8) == 0, ( - "If using curriculum, num_workers must be 0 to prevent data caching. " - "If the dataloader caches data before the batch is done the " - "curriculum sampler won't have the opportunity to reorder it. " - ) - - # Use a sampler to facilitate checkpoint resumption. - # If shuffling is enabled in the data configuration, create a random sampler. - elif data_config.shuffle: - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(data_config.get("seed", int(time.time()))) - sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) - else: - # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. - sampler = SequentialSampler(data_source=dataset) - - return sampler diff --git a/astune/utils/process_manager.py b/astune/utils/process_manager.py deleted file mode 100644 index 3f619074..00000000 --- a/astune/utils/process_manager.py +++ /dev/null @@ -1,85 +0,0 @@ -import subprocess -import argparse -import shutil -import time -import sys -import os -import signal -import shlex - - -def _replace_placeholder_in_config(config_obj, placeholder: str, replacement: str): - """Recursively replace placeholder in all string values within dict/list structures. - - - Traverses dicts and lists deeply - - Replaces all occurrences of `placeholder` inside string values - - Leaves non-string scalars untouched - """ - - def _walk(node): - if isinstance(node, dict): - return {k: _walk(v) for k, v in node.items()} - if isinstance(node, list): - return [_walk(v) for v in node] - if isinstance(node, str): - return node.replace(placeholder, replacement) - return node - - return _walk(config_obj) - - -def kill_process_with_keyword(keyword: str, exclude_substrings=None, grace_seconds: float = 1.0): - """Use bash pipelines to kill processes matching keyword quickly. - - - Filters out processes containing any exclude_substrings - - Excludes current launcher process - - Sends TERM once to all PIDs, then KILL once to all PIDs after a short grace period - - Returns list of PIDs targeted - """ - if exclude_substrings is None: - exclude_substrings = ["vscode"] - - self_pid = os.getpid() - - # Build a fast PID collector using pgrep if available; fallback to ps/grep - # We prefer pgrep -af to filter by full command and then extract PID (column 1) - exclude_filters = " ".join([f"| grep -v -F {shlex.quote(s)}" for s in exclude_substrings]) - pid_list_cmd = ( - f"(pgrep -af -- {shlex.quote(keyword)} 2>/dev/null || true) " - f"{exclude_filters} | awk '{{print $1}}' | grep -v -x {self_pid} || true" - ) - - try: - res = subprocess.run(["bash", "-lc", pid_list_cmd], capture_output=True, text=True, check=False) - pids = [pid for pid in res.stdout.split() if pid.isdigit()] - except Exception as e: - print(f"Failed to list PIDs via bash: {e}") - pids = [] - - # Fallback to ps/grep if pgrep path produced nothing (e.g., no pgrep installed) - if not pids: - ps_pid_cmd = ( - f"ps -eo pid,command -ww | grep -F -- {shlex.quote(keyword)} | grep -v grep " - f"{exclude_filters} | awk '{{print $1}}' | grep -v -x {self_pid} || true" - ) - try: - res2 = subprocess.run(["bash", "-lc", ps_pid_cmd], capture_output=True, text=True, check=False) - pids = [pid for pid in res2.stdout.split() if pid.isdigit()] - except Exception as e: - print(f"Failed to list PIDs via ps/grep: {e}") - pids = [] - - if not pids: - return [] - - pid_args = " ".join(pids) - try: - # Send TERM to all in one call - subprocess.run(["bash", "-lc", f"kill -TERM -- {pid_args} 2>/dev/null || true"], check=False) - time.sleep(grace_seconds) - # Escalate with KILL once; ignore failures for already-exited PIDs - subprocess.run(["bash", "-lc", f"kill -KILL -- {pid_args} 2>/dev/null || true"], check=False) - except Exception as e: - print(f"Error issuing kill commands: {e}") - - return [int(p) for p in pids] \ No newline at end of file diff --git a/astune/utils/pty.py b/astune/utils/pty.py deleted file mode 100644 index b6dbcbf2..00000000 --- a/astune/utils/pty.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -import pty - -def run_command_with_pty(cmd, working_dir, env_dict): - """ - 使用伪终端运行命令,并将输出写入日志文件。 - - 参数: - cmd (list): 要运行的命令(如 ["ls", "-l"])。 - working_dir (str): 工作目录。 - env_dict (dict): 环境变量字典。 - """ - # 保存原始环境变量 - original_env = os.environ.copy() - original_dir = os.getcwd() - - try: - # 切换到指定工作目录 - os.chdir(working_dir) - - # 更新环境变量 - for key, value in env_dict.items(): - os.environ[key] = value - - # # 打开日志文件以追加模式写入 - # with open(log_file, 'a') as log_f: - - # 定义主设备读取回调函数 - def master_read(fd): - try: - # 从主设备读取数据 - data = os.read(fd, 1024) - except OSError: - return b"" - - if data: - # 将数据写入日志文件 - # log_f.write(data.decode()) - # log_f.flush() - # 同时打印到标准输出(可选) - print(data.decode(), end="") - return data - - # 定义标准输入读取回调函数 - def stdin_read(fd): - # 如果不需要从标准输入读取数据,直接返回空字节 - return b"" - - # 使用 pty.spawn 分配伪终端并运行命令 - pty.spawn(cmd, master_read, stdin_read) - - finally: - # 恢复原始工作目录 - os.chdir(original_dir) - - # 恢复原始环境变量 - os.environ.clear() - os.environ.update(original_env) - -import base64 - -# 将字符串转换为 Base64 -def string_to_base64(s): - # 首先将字符串编码为字节 - s_bytes = s.encode('utf-8') - # 将字节转换为 base64 - base64_bytes = base64.b64encode(s_bytes) - # 将 base64 字节转换回字符串 - base64_string = base64_bytes.decode('utf-8') - return base64_string - -# 将 Base64 转换回字符串 -def base64_to_string(b): - # 将 base64 字符串转换为字节 - base64_bytes = b.encode('utf-8') - # 解码 base64 字节 - message_bytes = base64.b64decode(base64_bytes) - # 将字节转换回字符串 - message = message_bytes.decode('utf-8') - return message - -def pty_wrapper( - cmd: list[str], - dir: str, - env_dict: dict = {}, -): - run_command_with_pty(cmd, working_dir=dir, env_dict=env_dict) - -def pty_wrapper_final(human_cmd, dir, env_dict): - print("[pty]: ", human_cmd) - pty_wrapper(["/bin/bash", "-c", human_cmd], dir, env_dict) - - -if __name__ == "__main__": - import argparse - import json - - parser = argparse.ArgumentParser(description="Run a shell command in a PTY with logging and custom env.") - parser.add_argument("--human-cmd", type=str, help="Shell command to run (as a string)") - parser.add_argument("--dir", type=str, default=".", help="Working directory") - parser.add_argument("--env", type=str, default="{}", help="Environment variables as JSON string, e.g. '{\"KEY\":\"VAL\"}'") - - args = parser.parse_args() - - try: - env_dict = json.loads(args.env) - if not isinstance(env_dict, dict): - raise ValueError - except Exception: - print("--env must be a valid JSON object string, e.g. '{\"KEY\":\"VAL\"}'. But get:", args.env) - exit(1) - - pty_wrapper_final(base64_to_string(args.human_cmd), args.dir, env_dict) - diff --git a/astune/utils/utils.py b/astune/utils/utils.py deleted file mode 100644 index 1e67afbe..00000000 --- a/astune/utils/utils.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Any, List, Dict -import asyncio - -# apply chat_template to a message, and then convert back to message -def convert_tool_to_user_message(tool_message, tokenizer, format="qwen"): - assert format == "qwen" - - if tool_message["role"] == "user": - return tool_message - elif tool_message["role"] == "tool" and len(tool_message["tool_calls"])>0: - assert len(tool_message["tool_calls"])==1 - return { - "role": "user", - "content": str(tool_message["tool_calls"][0]['result']) - } - - -def run_async_coro__no_matter_what_the_fuck(coro): - try: - asyncio.get_running_loop() - in_loop = True - except RuntimeError: - in_loop = False - if not in_loop: - final_res = asyncio.run(coro) - else: - import threading - _res_holder = {} - _exc_holder = {} - def _run(): - try: - _res_holder["res"] = asyncio.run(coro) - except Exception as _e: - _exc_holder["exc"] = _e - - _t = threading.Thread(target=_run, daemon=True) - _t.start() - _t.join() - if "exc" in _exc_holder: - raise _exc_holder["exc"] - final_res = _res_holder["res"] - return final_res \ No newline at end of file diff --git a/boot_sms.py b/boot_sms.py deleted file mode 100644 index d096c3bf..00000000 --- a/boot_sms.py +++ /dev/null @@ -1,9 +0,0 @@ -def send_train_message(message: str): - import requests, os # 发送短信汇报训练进程 - assert len(message) < 64, f"Message too long: {(message)}" - try: requests.post(json={"phone_numbers": "18810508767", "server_code": "DLC", "error": message, "error_level": "无"}, url=os.getenv("ALIYUN_SMS_SERVICE", "http://cloud-6.agent-matrix.com:12337/send-sms/"), headers={"Content-Type": "application/json"}) - except Exception as e: print(f"Failed to send sms: {e}") - print('sms send') - - -send_train_message("容器启动") \ No newline at end of file diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 00000000..eac09687 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,2 @@ +# MyST build outputs +_build diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 00000000..fe7ad2ea --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1,128 @@ +# Book settings +# Learn more at https://jupyterbook.org/customize/config.html + +project: "AgentJet" +title: "AgentScope
Tuner" +author: Alibaba Tongyi Lab +logo: logo.png +copyright: "2025, Tongyi Lab, Alibaba Inc." +only_build_toc_files: true + +# Force re-execution of notebooks on each build. +# See https://jupyterbook.org/content/execute.html +execute: + execute_notebooks: off + +parse: + myst_enable_extensions: + - colon_fence + +# Define the name of the latex output file for PDF builds +latex: + latex_documents: + targetname: book.tex + +# Add a bibtex file so that we can create citations +bibtex_bibfiles: + - references.bib + +# Sphinx settings +sphinx: + extra_extensions: + - sphinx.ext.autodoc + - sphinx.ext.viewcode + - sphinx.ext.napoleon + - sphinx.ext.intersphinx + - sphinx.ext.autosummary + config: + # API Documentation Configuration + autosummary_generate: True + autosummary_imported_members: True + + # Autodoc Configuration + autodoc_typehints: 'description' + autodoc_member_order: 'bysource' + autodoc_default_options: + members: True + member-order: 'bysource' + special-members: '__init__' + undoc-members: True + exclude-members: '__weakref__' + + # Napoleon Configuration + napoleon_google_docstring: True + napoleon_numpy_docstring: True + napoleon_include_init_with_doc: False + napoleon_include_private_with_doc: False + napoleon_include_special_with_doc: True + napoleon_use_admonition_for_examples: False + napoleon_use_admonition_for_notes: False + napoleon_use_admonition_for_references: False + napoleon_use_ivar: False + napoleon_use_param: True + napoleon_use_rtype: True + + # Intersphinx Configuration + intersphinx_mapping: + python: ['https://docs.python.org/3', null] + numpy: ['https://numpy.org/doc/stable/', null] + + # Theme Configuration + html_theme: furo + pygments_style: "friendly" + html_favicon: logo.png + html_show_sphinx: false + html_last_updated_fmt: "%Y-%m-%d" + html_copy_source: false + html_show_sourcelink: false + templates_path: ["./_templates"] + html_static_path: + - "_static" + use_multitoc_numbering: false + html_js_files: + - language.js + html_sidebars: + "**": + - "sidebar/scroll-start.html" + - "sidebar/brand.html" + - "language-switch.html" + - "version-switch.html" + - "sidebar/search.html" + - "sidebar/navigation.html" + - "sidebar/ethical-ads.html" + - "sidebar/scroll-end.html" + html_theme_options: + top_of_page_buttons: ["view"] + sidebar_hide_name: false + source_repository: "https://github.com/modelscope/AgentJet" + source_branch: "main" + source_directory: "docs/" + footer_icons: + - name: GitHub + url: "https://github.com/modelscope/AgentJet" + html: | + + + + class: "" + - name: Discord + url: "https://discord.gg/eYMpfnkG8h" + html: | + + + + class: "" + - name: DingTalk + url: "https://qr.dingtalk.com/action/joingroup?code=v1,k1,OmDlBXpjW+I2vWjKDsjvI9dhcXjGZi3bQiojOq3dlDw=&_dt_no_comment=1&origin=11" + html: | + + + + class: "" + light_css_variables: + color-brand-primary: "#2196f3" + color-brand-content: "#2196f3" + color-admonition-background: "#f8f9fa" + dark_css_variables: + color-brand-primary: "#64b5f6" + color-brand-content: "#64b5f6" diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 00000000..80442e9f --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,205 @@ +.language-switch { + padding: 1rem 0; + border-bottom: 1px solid var(--color-background-border); + margin-bottom: 1rem; +} + +.language-selector { + position: relative; + font-size: var(--font-size--small); +} + +.language-button { + width: 95%; + padding: 0.5rem 0.75rem; + background: var(--color-background-secondary); + border: 1px solid var(--color-background-border); + border-radius: 0.375rem; + cursor: pointer; + display: flex; + justify-content: space-between; + align-items: center; + color: var(--color-foreground-primary); + font-family: var(--font-stack); + font-size: inherit; + font-weight: 400; + transition: all 0.15s ease-out; + box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.05); + margin: 0 auto; +} + +.language-button:hover { + background: var(--color-background-hover); + border-color: var(--color-brand-primary); + box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px 0 rgb(0 0 0 / 0.06); +} + +.language-button:focus { + outline: 2px solid var(--color-brand-primary); + outline-offset: 2px; +} + +.language-menu { + display: none; + position: absolute; + top: calc(100% + 0.25rem); + left: 0; + right: 0; + background: var(--color-background-primary); + border: 1px solid var(--color-background-border); + border-radius: 0.375rem; + box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -2px rgb(0 0 0 / 0.05); + z-index: 1000; + overflow: hidden; + animation: fadeInUp 0.2s ease-out; +} + +@keyframes fadeInUp { + from { + opacity: 0; + transform: translateY(-0.5rem); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +.language-option { + display: block; + padding: 0.5rem 0.75rem; + text-decoration: none; + color: var(--color-foreground-primary); + font-family: var(--font-stack); + font-size: inherit; + font-weight: 400; + transition: all 0.15s ease-out; + border-bottom: 1px solid var(--color-background-border); + position: relative; +} + +.language-option:last-child { + border-bottom: none; +} + +.language-option:hover { + background: var(--color-background-hover); + text-decoration: none; + color: var(--color-foreground-primary); +} + +.language-option:focus { + background: var(--color-background-hover); + outline: none; + text-decoration: none; +} + +.language-option.active { + background: var(--color-brand-primary); + color: var(--color-brand-content); + font-weight: 500; +} + +.language-option.active:hover { + background: var(--color-brand-primary); + color: var(--color-brand-content); +} + +.language-button::after { + content: ""; + width: 0; + height: 0; + border-left: 0.25rem solid transparent; + border-right: 0.25rem solid transparent; + border-top: 0.25rem solid currentColor; + transition: transform 0.15s ease-out; + opacity: 0.7; +} + +.language-selector.open .language-button::after { + transform: rotate(180deg); +} + +.language-selector.open .language-menu { + display: block; +} + +.toctree-l1[data-lang="en"] { + display: var(--lang-en-display, block); +} + +.toctree-l1[data-lang="zh"] { + display: var(--lang-zh-display, block); +} + +.toctree-l1[data-lang="both"] { + display: block; +} + +body[data-current-lang="en"] { + --lang-en-display: block; + --lang-zh-display: none; +} + +body[data-current-lang="zh"] { + --lang-en-display: none; + --lang-zh-display: block; +} + +@media (max-width: 768px) { + .language-switch { + padding: 0.75rem 0; + margin-bottom: 0.75rem; + } + + .language-button { + padding: 0.625rem 0.75rem; + } + + .language-option { + padding: 0.625rem 0.75rem; + } +} + +@media (prefers-contrast: high) { + .language-button { + border-width: 2px; + } + + .language-menu { + border-width: 2px; + } +} + +@media (prefers-reduced-motion: reduce) { + .language-button, + .language-option, + .language-button::after { + transition: none; + } + + .language-menu { + animation: none; + } +} + +/* Dark mode support */ +@media (prefers-color-scheme: dark) { + .language-button, + .language-option { + color: var(--color-foreground-secondary); + } +} + +.sidebar-logo-container .sidebar-logo { + max-height: 170px; + width: auto; + display: block; +} + +.sidebar-brand-text { + display: flex; + flex-direction: column; + align-items: center; + line-height: 0.6; +} diff --git a/docs/_static/language.js b/docs/_static/language.js new file mode 100644 index 00000000..0d089591 --- /dev/null +++ b/docs/_static/language.js @@ -0,0 +1,32 @@ +(function () { + "use strict"; + + function getCurrentPageLanguage() { + const path = window.location.pathname; + if (path.includes("/zh/")) { + return "zh"; + } + return "en"; + } + + function autoSetLanguage() { + const currentLang = getCurrentPageLanguage(); + const savedLang = localStorage.getItem("preferred-language"); + + if (currentLang !== savedLang) { + localStorage.setItem("preferred-language", currentLang); + } + + setTimeout(() => { + if (window.switchLanguage) { + window.switchLanguage(currentLang, true); + } + }, 5); + } + + if (document.readyState === "loading") { + document.addEventListener("DOMContentLoaded", autoSetLanguage); + } else { + autoSetLanguage(); + } +})(); diff --git a/docs/_templates/language-switch.html b/docs/_templates/language-switch.html new file mode 100644 index 00000000..e2ee1fb8 --- /dev/null +++ b/docs/_templates/language-switch.html @@ -0,0 +1,103 @@ +
+
+ +
+ English + 中文 +
+
+ +
+ +
+
+
+
+ + diff --git a/docs/_templates/version-switch.html b/docs/_templates/version-switch.html new file mode 100644 index 00000000..e3e5821b --- /dev/null +++ b/docs/_templates/version-switch.html @@ -0,0 +1,161 @@ + + + diff --git a/docs/_toc.yml b/docs/_toc.yml new file mode 100644 index 00000000..7eb76610 --- /dev/null +++ b/docs/_toc.yml @@ -0,0 +1,65 @@ +format: jb-book +root: en/introduction.md +parts: + # --- 英文部分 --- + - caption: Tutorial + chapters: + - file: en/intro.md + - file: en/installation.md + - file: en/quickstart.md + - file: en/tune_your_first_agent.md + + - caption: Example + chapters: + - file: en/example_math_agent.md + - file: en/example_app_world.md + - file: en/example_werewolves.md + - file: en/example_learning_to_ask.md + - file: en/example_frozenlake.md + - file: en/example_countdown.md + + - caption: Component + chapters: + - file: en/workflow.md + - file: en/data_pipeline.md + - file: en/task_judger.md + + - caption: Deep Dive + chapters: + - file: en/configuration.md + - file: en/visualization.md + - file: en/beast_logger.md + - file: en/data_generation.md + - file: en/example_tracing_feedback_loop.md + + + # --- 中文部分 --- + - caption: 教程 + chapters: + - file: zh/intro.md + - file: zh/installation.md + - file: zh/quickstart.md + - file: zh/tune_your_first_agent.md + + - caption: 示例 + chapters: + - file: zh/example_math_agent.md + - file: zh/example_app_world.md + - file: zh/example_werewolves.md + - file: zh/example_learning_to_ask.md + - file: zh/example_frozenlake.md + - file: zh/example_countdown.md + + - caption: 组件 + chapters: + - file: zh/workflow.md + - file: zh/data_pipeline.md + - file: zh/task_judger.md + + - caption: 深入探索 + chapters: + - file: zh/configuration.md + - file: zh/visualization.md + - file: zh/beast_logger.md + - file: zh/data_generation.md + - file: zh/example_tracing_feedback_loop.md diff --git a/docs/agentjet.jpg b/docs/agentjet.jpg new file mode 100644 index 00000000..6d51f4a9 Binary files /dev/null and b/docs/agentjet.jpg differ diff --git a/docs/en/agent_framework_support.md b/docs/en/agent_framework_support.md new file mode 100644 index 00000000..8f9e0135 --- /dev/null +++ b/docs/en/agent_framework_support.md @@ -0,0 +1,54 @@ +# Agent Framework Support + +AgentJet currently supports (and has been tested with) the following agentic frameworks. + + diff --git a/docs/en/beast_logger.md b/docs/en/beast_logger.md new file mode 100644 index 00000000..322306bd --- /dev/null +++ b/docs/en/beast_logger.md @@ -0,0 +1,47 @@ +# Beast-Logger Usage + +Beast-logger is a logging kit built for LLM systems, +providing reliable high-resolution token-level LLM activity +that is unprecedented in any other projects. + +Here is how to use beast-logger in AgentJet. + +## Usage in AgentJet + +1. Start training or debugging with AgentJet launcher. + +2. Wait until the first batch is completed. + +3. Locate log files. By default, they will be placed at `saved_experiments/${experiment_name}`. For example: +`saved_experiments/benchmark_frozenlake_20251223_2305` + +4. Run `beast_logger_go` command in the VSCode terminal (or any other software with port-forwarding ability) to start the web log-viewer. Click `http://127.0.0.1:8181` to open it (VSCode will automatically forward this port from server -> your local computer) + +
+ image +
+ +5. Fill the **ABSOLUTE** path of the log files and click `submit`. + + > Hint: absolute path is recommended. + > + > However, you can also use relative path, if `beast_logger_go` command is launched at same working dir. + + > Warning: Beast-logger recursively scans this path, + > + > thus, where possible, selects the innermost directory containing the fewest files to read logs faster. + +
+ image +
+ + +6. Choose entry to display + + - Yellow tokens: tokens that are excluded from loss computation. + - Blue tokens: tokens that participant loss computation. + - Hovor your mouse on one of the tokens: show the logprob value of that token. + +
+image +
diff --git a/docs/en/component.md b/docs/en/component.md new file mode 100644 index 00000000..f7347dfb --- /dev/null +++ b/docs/en/component.md @@ -0,0 +1 @@ +# Component diff --git a/docs/en/configuration.md b/docs/en/configuration.md new file mode 100644 index 00000000..16c4db5e --- /dev/null +++ b/docs/en/configuration.md @@ -0,0 +1,268 @@ +# Configuration + +This page provides a detailed description of the configuration files for AgentJet. + + + +## Overview + +AgentJet uses YAML-format configuration files to set up data, algorithms, rewards, logging, and other runtime behaviors. + +!!! info "Default Configuration" + The default config is located at `ajet/default_config/ajet_default.yaml`. + +At a high level, a typical config contains a single root section `ajet`, which is divided into several logical parts: + +
+ +-  **Basic Metadata** — Project name, experiment name, experiment directory, and backbone selection + - `project_name`, `experiment_name`, `experiment_dir` + - `backbone`: Select training backend (`debug`, `trinity`, or `verl`) + +-  **Data & Reward** — How to load data and evaluate agents + - `task_reader`: Load training/validation samples + - `task_judge`: Evaluate agents and compute rewards + - `data`: Prompt/response length and batch sizes + +-  **Model & Rollout** — Model configuration and agent interaction + - `model`: Base model to train + - `rollout`: Agent-environment interaction settings + - `context_tracker`: Conversation/history management + +
+ + + +## Model Configuration + +### Specifying the Model + +```yaml title="config.yaml" +ajet: + model: + path: path/to/model +``` + +| Source Type | Example | +|-------------|---------| +| **Local file** | `/mnt/data/models/Qwen2.5-14B-Instruct` | +| **HuggingFace repo** | `Qwen/Qwen2.5-14B-Instruct` (auto-downloaded) | + +### Environment Variables for LLM-as-Judge + +If using LLM-as-a-Judge, configure these environment variables: + +```bash +# DashScope API key for remote LLM calling +export DASHSCOPE_API_KEY='sk-xxxxxx|sk-yyyyyy' +export DASHSCOPE_API_KEY_BACKUP='sk-zzzzzz' +``` + + +## Data Configuration + +### Task Reader + +`task_reader` defines how to read training and validation data. + +=== "EnvService" + + ```yaml + ajet: + task_reader: + type: env_service + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code + training_split: train + validation_split: dev + ``` + +=== "JSONL File" + + ```yaml + ajet: + task_reader: + type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: "data/train.jsonl" + validation: + file_path: "data/val.jsonl" + ``` + +=== "HuggingFace" + + ```yaml + ajet: + task_reader: + type: huggingface_dat_repo + huggingface_dat_repo: + dataset_path: "gsm8k" + training_split: "train" + validation_split: "validation" + ``` + +### Task Judge + +`task_judge` evaluates agent performance and calculates rewards. + +```yaml title="config.yaml" +ajet: + task_judge: + judge_type: customized_protocol # or 'rubrics_auto_grader' + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 512 +``` + +| Option | Description | +|--------|-------------| +| `customized_protocol` | Use a custom Python class for scoring | +| `rubrics_auto_grader` | Use LLM-based automatic grading | + + +## Training Configuration + +### Backend Selection + +AgentJet supports three training backends: + +| Backend | Description | +|---------|-------------| +| **trinity** | Default. Flexible and scalable framework for RL fine-tuning | +| **verl** | Volcano Engine reinforcement learning for LLMs | +| **debug** | Allows breakpoint debugging in IDEs | + +```yaml title="config.yaml" +ajet: + backbone: trinity # debug, trinity, or verl +``` + +### Rollout Configuration + +Controls agent behavior during environment interaction: + +```yaml title="config.yaml" +ajet: + rollout: + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + max_env_worker: 128 + temperature: 0.9 + top_p: 1.0 + name: vllm + n_vllm_engine: 2 + num_repeat: 4 +``` + +| Parameter | Description | +|-----------|-------------| +| `user_workflow` | Path to workflow implementation class | +| `temperature` / `top_p` | Sampling parameters | +| `name` | Inference engine (e.g., `vllm`) | +| `n_vllm_engine` | Number of vLLM engines (Trinity only) | + +### Common Training Parameters + +```yaml title="config.yaml" +ajet: + trainer_common: + total_epochs: 50 + save_freq: 20 + test_freq: 20 + val_before_train: False + val_pass_n: 4 + nnodes: 1 + n_gpus_per_node: 8 + mini_batch_num: 1 + fsdp_config: + param_offload: True + optimizer_offload: True +``` + +| Parameter | Description | +|-----------|-------------| +| `total_epochs` | Total training epochs | +| `save_freq` | Checkpoint save frequency (steps) | +| `test_freq` | Validation frequency (steps) | +| `nnodes` / `n_gpus_per_node` | Distributed training setup | +| `fsdp_config` | FSDP memory optimization | + +### Optimization Algorithms + +```yaml title="config.yaml" +ajet: + trainer_common: + algorithm: + adv_estimator: grpo + use_kl_in_reward: False + optim: + lr: 1e-6 + use_kl_loss: True + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl +``` + +| Parameter | Description | +|-----------|-------------| +| `adv_estimator` | Advantage estimator (e.g., `grpo`) | +| `lr` | Learning rate | +| `use_kl_loss` | Include KL divergence in loss | +| `kl_loss_coef` | KL loss coefficient | + + +## Debug Mode + +When `backbone: debug`, additional settings are available: + +```yaml title="config.yaml" +ajet: + debug: + debug_max_parallel: 16 + debug_first_n_tasks: 2 + debug_vllm_port: 18000 + debug_vllm_seed: 12345 + debug_tensor_parallel_size: 4 +``` + +!!! tip "Debug Mode Use Cases" + - **Limiting tasks**: Quickly verify the pipeline on a few tasks + - **Fixing randomness**: `debug_vllm_seed` helps reproduce issues + - **Reduced parallelism**: Easier to debug with smaller concurrency + + +## Logging & Monitoring + +### Logger Selection + +```yaml title="config.yaml" +ajet: + trainer_common: + logger: swanlab # console, wandb, or swanlab +``` + +| Logger | Description | +|--------|-------------| +| `console` | Standard output for quick progress checking | +| `wandb` | Weights & Biases experiment tracking | +| `swanlab` | SwanLab logging | + +### Output Structure + +All experiment outputs are saved in `./launcher_record/{experiment_name}`: + +| Directory | Contents | +|-----------|----------| +| **Logs** | Logs and error messages | +| **Metrics** | Training metrics (depends on logger) | +| **Checkpoint** | Model checkpoints | + + + +## Next Steps + + diff --git a/docs/en/data_generation.md b/docs/en/data_generation.md new file mode 100644 index 00000000..23bca456 --- /dev/null +++ b/docs/en/data_generation.md @@ -0,0 +1,187 @@ +# Data Generation + +## Introduction +`Data Generation Reader` provides an intelligent data generation method designed to simplify the creation of high-quality training data. The method is flexible and efficient, capable of generating domain-specific tasks based on few-shot examples and optional documents. + +## Method +`Data Generation Reader` employs a two-stage task generation process: + +### Stage 1 (Optional): Document-based Data Generation +This stage is optional. `Document-based Data Generation` generates knowledge-based tasks based on the provided documents. Users can provide one or more documents (supporting formats like PDF, Word, TXT, etc.): + +```plain +According to the Anti-Money Laundering and Counter-Terrorist Financing Ordinance and related Guideline, banks are required to identify and take reasonable measures to verify the identity of the beneficial owner of corporate customers so that the bank is ... +``` + +The generator reads the document content and guides the LLM to batch-generate tasks related to the document content: + +```json +[ + { + "main_query": "What are the key requirements of Customer Due Diligence in AML procedures?", + "related_doc": "Customer Due Diligence measures should include: (a) identifying the customer and verifying the customer's identity..." + }, + { + "main_query": "How should financial institutions handle Suspicious Transaction Reports?", + "related_doc": "When someone knows or suspects that any property represents the proceeds of an indictable offense..." + } + ... +] +``` + +If documents are provided for data generation, the data generated in this stage will be added to the validation task set for the subsequent training process. + +### Stage 2: Few-shot Data Generation +This stage generates the final training tasks. `Few-shot Data Generation` combines a few user-provided tasks with the knowledge-based tasks generated in the first stage, and use the documents as references to generate training tasks. First, the user needs to provide a few task examples: + +```json +{"main_query": "Can banks ask corporate customers to provide information of its ownership?", "answer": "According to the Anti-Money Laundering and ..."} +{"main_query": "Can a bank close my account?", "answer": "Either a customer or a bank may close an account at any time subject to any specific terms and ..."} +... +``` + +These examples will be merged with the tasks generated in the first stage to form an example task set. The generator will sample from this set to be used as few-shot demonstrations, and combined with relevant documents, guide the LLM to batch-generate training tasks: + +```json +[ + { + "main_query": "Are financial institutions required to verify the source of funds for corporate clients during account opening?" + }, + { + "main_query": "What are the requirements for banks to verify customer identities under anti-money laundering regulations?" + } + ... +] +``` + +## Quick Start +`Data Generation Reader` can load a few user-provided tasks and optional documents (in various formats such as PDF, Word, and TXT) from a local path, then generates tasks and loads them as training tasks. + +### Step 1: Prepare data +Provide a few example tasks: + +```json +{"main_query": "What is the capital of France?", "answer": "..."} +{"main_query": "How to cook pasta?", "answer": "..."} +``` + +(Optional) Provide documents and place them in the specified directory: + +```bash +mkdir -p dataset/document +cp your-document.pdf dataset/document/ +``` + +### Step 2: Generate Training Tasks +#### Method 1: Integrate Data Generation into the Training Pipeline +Copy and modify the key configuration parameters in `ajet/default_config/ajet_default.yaml`, and set `ajet.task_reader.type` to `data_generation` to enable this reader. + +```yaml +ajet: + task_reader: + type: data_generation + # when `type == data_generation` + data_generation: + # Document reader configuration + document_reader: + document_path: + - 'dataset/document/your-document1.pdf' + - 'dataset/document/your-document2.pdf' + languages: + - eng + # Task reader (for existing tasks) + query_reader: + type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: 'dataset/jsonl/your-queries.jsonl' + # Number of tasks to generate + task_num: 10 + # LLM config + llm_model: qwen-long + llm_response_length: 8192 + num_workers: 32 + sampling_params: + temperature: 0 + # Task filtering config + deduplication_filter: + enabled: true + params: + similarity_threshold: 0.8 + db_path: ./.similarity_db + model: text-embedding-v4 + api_key: null # load from the env + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 +``` + +#### Method 2: Run the Generation Script +```python +from ajet.data_generator.config import * +from ajet.task_reader.data_generator_reader import DataGeneratorTaskReader + +def run(): + config = TaskReaderConfig( + data_generation=DataGenerationConfig( + document_reader=DocumentReaderConfig( + document_path=['dataset/document/your-document1.pdf', 'dataset/document/your-document2.pdf'], + languages=["eng"], + chunk_size=5120, + split_by="sentence", + ), + query_reader=QueryReaderConfig( + type="jsonl_dataset_file", + jsonl_dataset_file=DatasetFileConfig( + training=TrainingDatasetConfig(file_path='dataset/jsonl/your-queries.jsonl') + ), + ), + task_num=50, + llm_model="qwen-long", + num_workers=16, + sampling_params=SamplingParamsConfig(temperature=0.0), + deduplication_filter=DeduplicationFilterConfig( + enabled=True, + params=DeduplicationFilterParamsConfig( + similarity_threshold=0.8, + model="text-embedding-v4", + ), + ), + ) + ) + reader = DataGeneratorTaskReader(reader_config=config) + +run() +``` + +## **Generated Task Examples** +Based on user-provided documents (optional) and a few task examples, the `Data Generation Reader` can batch-generate training tasks: + +```json +[ + { + "main_query": "Are financial institutions required to verify the source of funds for corporate clients during account opening?" + }, + { + "main_query": "What are the requirements for banks to verify customer identities under anti-money laundering regulations?" + } + ... +] +``` + +## Detailed Config Options +| Parameter Path | Type | Default | Required | Description | +| --- | --- | --- | --- | --- | +| `document_reader.document_path` | list[str] | - | No | List of document file paths. Supports PDF, Word, TXT, and more. | +| `document_reader.languages` | list[str] | `['eng']` | No | List of document languages for OCR and text parsing, e.g., `eng` (English), `chs` (Simplified Chinese). | +| `query_reader.type` | str | `jsonl_dataset_file` | Yes | Reader type. Options: `jsonl_dataset_file`, `env_service`, `huggingface_dat_repo`. | +| `query_reader.jsonl_dataset_file.training.file_path` | str | - | Yes | Path to the training tasks JSONL file (when `type: jsonl_dataset_file`). | +| `task_num` | int | `10` | Yes | Number of tasks to generate. The actual number may be reduced by filtering. | +| `llm_model` | str | `qwen-long` | Yes | LLM model name used for task generation. | +| `llm_response_length` | int | `8192` | No | Maximum number of tokens in the LLM response. | +| `num_workers` | int | `32` | No | Number of parallel worker threads for speeding up task generation. | +| `sampling_params.temperature` | float | `0` | No | Sampling temperature. `0` means greedy decoding (deterministic output); higher values make outputs more random. | +| `deduplication_filter.enabled` | bool | `true` | No | Whether to enable the deduplication filter. | +| `deduplication_filter.params.similarity_threshold` | float | `0.8` | Yes | Similarity threshold (0–1). Tasks above this threshold will be filtered out. | +| `deduplication_filter.params.db_path` | str | `./.similarity_db` | No | Path to the similarity database used to cache embeddings. | +| `deduplication_filter.params.model` | str | `text-embedding-v4` | Yes | Embedding model used to compute similarity. | +| `deduplication_filter.params.api_key` | str | `null` | No | API key. If `null`, it will be loaded from the `DASHSCOPE_API_KEY` environment variable. | +| `deduplication_filter.params.base_url` | str | `https://dashscope.aliyuncs.com/compatible-mode/v1` | No | Base URL for the embedding API. | diff --git a/docs/en/data_pipeline.md b/docs/en/data_pipeline.md new file mode 100644 index 00000000..e128209e --- /dev/null +++ b/docs/en/data_pipeline.md @@ -0,0 +1,173 @@ +# Task Reader + +AgentJet loads training tasks from various data sources through Task Reader. This page covers the Task schema definition and different built-in Task Readers for common scenarios. + +--- + +## Overview + +In agent training, all training data must be represented as **tasks** following a unified schema. + +!!! info "Key Concepts" + - **Unified Schema**: All tasks conform to the `Task` structure regardless of source + - **Multiple Sources**: Load from local files, HuggingFace datasets, interactive environments, or auto-generate new tasks + - **Automatic Routing**: The framework selects the appropriate reader based on `ajet.task_reader.type` + +``` +Data Source → Task Reader → Unified Task Schema → Training Pipeline +``` + +--- + +## Task Schema + +All training tasks must be defined according to the following structure: + +```python +class Task(BaseModel): + main_query: str = Field(default="") + init_messages: List[dict] = Field(default=[]) + task_id: str = Field(default="") + env_type: str = Field(default="") + metadata: dict = Field(default_factory=dict) +``` + +### Field Descriptions + +| Field | Type | Description | +|-------|------|-------------| +| `main_query` | `str` | The main instruction or question for the agent to solve | +| `init_messages` | `List[dict]` | Initial conversation messages (e.g., system prompts). Each must have `role` and `content` fields | +| `task_id` | `str` | Unique identifier for the task | +| `env_type` | `str` | Environment type (e.g., "math", "appworld") | +| `metadata` | `dict` | Additional context information (e.g., reference answers for reward calculation) | + +### Example Task + +```json title="example_task.json" +{ + "main_query": "What is 15 * 23?", + "init_messages": [ + { + "role": "system", + "content": "You are a helpful math assistant." + } + ], + "task_id": "math_001", + "env_type": "math", + "metadata": { + "answer": "345", + "difficulty": "easy" + } +} +``` + +!!! tip "Best Practices" + - Use `metadata` to store information needed for reward computation (e.g., reference answers, scoring rubrics) + - Keep `main_query` clear and concise + - Use `init_messages` for system prompts or few-shot examples + +--- + +## Built-in Task Readers + +AgentJet provides multiple built-in Task Readers for different scenarios. The framework automatically routes to the correct reader based on `ajet.task_reader.type`. + +### Quick Selection Guide + +
+

JSONL File

You have prepared task data in JSONL format locally.

+

HuggingFace

Load tasks from HuggingFace Hub (e.g., GSM8K, MATH).

+

EnvService

Tasks come from a running environment service.

+
+ +--- + +### 1. JSONL File Reader + +**When to use:** You have prepared training tasks in JSONL format locally. + +=== "Configuration" + + ```yaml title="config.yaml" + ajet: + task_reader: + type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: "data/train.jsonl" + validation: + file_path: "data/val.jsonl" + ``` + +=== "JSONL Format" + + Each line should be a JSON object conforming to the Task schema: + + ```json title="data/train.jsonl" + {"main_query": "Solve: x + 5 = 12", "task_id": "algebra_01", "env_type": "math", "metadata": {"answer": "7"}} + {"main_query": "What is the capital of France?", "task_id": "geo_01", "env_type": "qa", "metadata": {"answer": "Paris"}} + ``` + +!!! note "How it works" + - Reads tasks line-by-line from specified JSONL files + - Automatically validates against Task schema + - Supports separate training and validation splits + +--- + +### 2. HuggingFace Dataset Reader + +**When to use:** Load tasks from HuggingFace Hub datasets (e.g., GSM8K, MATH). + +```yaml title="config.yaml" +ajet: + task_reader: + type: huggingface_dat_repo + huggingface_dat_repo: + dataset_path: "gsm8k" # HF dataset repo name + dataset_name: "main" # Optional: dataset subset name + training_split: "train" # Training split name + validation_split: "test" # Validation split name +``` + +!!! note "How it works" + - Downloads dataset from HuggingFace Hub using `datasets` library + - Automatically maps dataset fields to Task schema + - Caches downloaded data locally for faster subsequent runs + +--- + +### 3. EnvService Reader + +**When to use:** Tasks are provided by an interactive environment service (e.g., AppWorld, RL gym environments). + +```yaml title="config.yaml" +ajet: + task_reader: + type: env_service + env_service: + env_type: "appworld" # Environment type + env_url: "http://127.0.0.1:8080" # Service URL + env_action_preference: code # Action format: code/text/box + training_split: train + validation_split: dev +``` + +!!! note "How it works" + - Connects to a running environment service via HTTP + - Pulls task instances from the environment + - Supports dynamic task generation from interactive environments + +!!! example "Use Cases" + - Training agents in simulated environments (e.g., FrozenLake, game environments) + - Complex interactive scenarios where tasks are generated dynamically + +--- + +## Next Steps + + diff --git a/docs/en/debugging_guide.md b/docs/en/debugging_guide.md new file mode 100644 index 00000000..ff7563a2 --- /dev/null +++ b/docs/en/debugging_guide.md @@ -0,0 +1,106 @@ +In this tutorial, we introduce the way to debug the workflow and the training algorithms. + + + +## Workflow Debugging (--backbone=debug) + +1. Install VSCode and connect to GPU server. + +VSCode is a open-source software and provides all debugging plugins for free. Therefore, we choose VSCode as our debugging platform. + +VSCode can connect to remote ssh server and operate it as if it is your local machine. +For more details, please refer to VSCode official documents. + +2. Install VSCode Python Extension Bundle + + +3. Create `.vscode/launch.json`. If `.vscode` does not exists yet, create it. + + +4. Copy and paste the following configuration into `launch.json` + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "./path/to/yaml.yaml" + ], + "env": {} + } + ] +} +``` + +5. Modify `./path/to/yaml.yaml` field to your task yaml. + + +6. For more sophisticated task with additional external service, add env variables or more args. For example, if your original training command is: + +```bash +export DASHSCOPE_API_KEY="sk-abcdefg" +ajet --conf tutorial/example_appworld/appworld.yaml --with-appworld --backbone='verl' +``` + +Then, the modified launch.json will be + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", // verl -> debug + "--conf", "tutorial/example_appworld/appworld.yaml", + "--with-appworld", + ], + "env": { + "DASHSCOPE_API_KEY": "sk-abcdefg" + } + } + ] +} +``` + +7. Press `F5` to start debugging. + +8. You can set breakpoint inside the workflow to observe program execution now. + + +## General Debugging (Ray Distributed Debugger) + +1. Install the Ray Distributed Debugger extension in VSCode. + +2. In AgentJet project: + + 2-1. In the place your want to set a conditional breakpoint, write + `from ajet import bp; bp("TAG_1")` + + 2-2. When launching the training process, add `--debug` as commandline argument + `ajet --conf your_config.yaml --debug="TAG_1"` + + 2-3. Open Tab "Ray Distributed Debugger" in VSCode, and just wait until the breakpoint is hit. + + +## Comparison + +| Feature | Workflow Debugging | General Debugging (Ray) | +| :--- | :--- | :--- | +| **Backend** | `debug`, `tinker` | `verl`, `trinity` | +| **Reboot Speed** | Very Fast | Slow | +| **Debug Target** | Workflow | Everything | +| **VSCode Extension** | Python | Python + Ray Distributed Debugger | +| **Launch Mode** | `F5` standard launch (via `launch.json`) | Command line execution with `ajet ... --debug="TAG"` | +| **Commandline** | `--backbone=debug` | `--debug="TAG1\|TAG2\|TAG3"` | diff --git a/docs/en/example.md b/docs/en/example.md new file mode 100644 index 00000000..5eebabbb --- /dev/null +++ b/docs/en/example.md @@ -0,0 +1,3 @@ +# Examples + +This section provides practical examples of how to use AgentJet in various scenarios. diff --git a/docs/en/example_app_world.md b/docs/en/example_app_world.md new file mode 100644 index 00000000..dc3aab80 --- /dev/null +++ b/docs/en/example_app_world.md @@ -0,0 +1,210 @@ +# AppWorld + +This tutorial demonstrates how to train an agent to interact with AppWorld and solve complex tasks through tool usage. + +## 1. Overview + +AppWorld is a high-fidelity execution environment of 9 day-to-day apps, operable via 457 APIs, populated with digital activities of 106 people living in a simulated world. The goal is to tune an agent that can effectively navigate and utilize these apps to complete complex tasks. + +This document is organized as follows: + +- Quick Start: run the example with minimal setup +- Understand: workflow loop, configuration, code locations, and reward +- Results: training curve and qualitative cases + +## 2. Quick Start + +### 2.1 Preparation + +First, download and unpack the Appworld services. The script below is idempotent: it clears any existing folder and re-downloads the archive. + +```bash +base_path="/tmp" +export APPWORLD_PATH="${base_path}/pack_all_in_one" +export APPWORLD_SCRIPT="bash EnvService/env_sandbox/appworld.sh" + +rm -rf "${APPWORLD_PATH}" +rm -f ./appworld_pack_v2.tar.gz + +wget -q "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v2.tar.gz" -O appworld_pack_v2.tar.gz +tar -xzf ./appworld_pack_v2.tar.gz -C "${base_path}" +``` + +Then export the environment variables (re-run in every new shell): + +```bash +export BASE_PATH=/tmp +export APPWORLD_PATH="${BASE_PATH}/pack_all_in_one" +export APPWORLD_SCRIPT="bash EnvService/env_sandbox/appworld.sh" +``` + +### 2.2 Start Training + +Run the training script: + +```bash +ajet --conf tutorial/example_appworld/appworld.yaml --with-appworld +``` + +
+Quick Debugging (Optional) + +If you want to breakpoint-debug the workflow/judge locally: + +```bash +# (optional) recommended cleanup before debug +# ajet --kill="python|ray" + +clear && \ +ajet --conf tutorial/example_appworld/math_agent.yaml --backbone='debug' --with-logview +``` + +When `--backbone=debug`, Ray is disabled. You can use a VSCode `.vscode/launch.json` like below: + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "./path/to/yaml.yaml" + ], + "env": {} + } + ] +} +``` +
+ +## 3. Understand + +This section explains how the AppWorld example is assembled: workflow, reward, configuration, and code locations. + +### 3.1 Core Process + +The AgentScope workflow code for the AppWorld example is located at `tutorial/example_appworld/appworld.py`. + +The code first defines the AgentScope workflow (set the agent's `model` to `tuner.as_agentscope_model()`): + +```python +agent = ReActAgent( + name="Qwen", + sys_prompt=first_msg["content"], + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=None, + print_hint_msg=False, +) + +env = workflow_task.gym_env + +for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): + # agentscope deal with interaction message + reply_message = await agent(interaction_message) + # env service protocol + obs, _, terminate, _ = env.step( + action={"content": reply_message.content, "role": "assistant"} + ) + # generate new message from env output + interaction_message = Msg(name="env", content=obs, role="user") + # is terminated? + if terminate: + break + if tuner.get_context_tracker().context_overflow: + break +``` + +In the above code: + +- `env.step`: simulates the gym interface. It takes an action as input and returns a four-tuple `(observation, reward, terminate_flag, info)`. +- `tuner.get_context_tracker().context_overflow`: checks whether the current context window has exceeded the token limit. + + +### 3.2 Reward + +In `ajet/task_judge/env_service_as_judge.py`, we read the reward signal from the environment via `env.evaluate(...)`. + +You can also refer to this file to implement your own Judge for your specific task. + +### 3.3 Configuration Details +Copy and modify the key parameters in `tutorial/example_appworld/appworld.yaml`. The parts most relevant to this document are marked with in the yaml file: + +1. **Read tasks** (corresponding config field: `ajet.task_reader`) +2. **Define the workflow** (corresponding config field: `ajet.rollout.user_workflow`) + - Example: if the AgentScope workflow is defined in the `ExampleAgentScopeWorkflow` class in `tutorial/example_appworld/appworld.py` + - Then set `ajet.rollout.user_workflow = "tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow"` +3. **Define the scoring function** (corresponding config field: `ajet.task_judge.judge_protocol`) + - Example: `ajet.task_judge.judge_protocol = "ajet.task_judge.env_service_as_judge->EnvServiceJudge"` +4. **Specify the model** (corresponding config field: `ajet.model.path`) + +```yaml +ajet: + project_name: example_appworld + experiment_name: "read_yaml_name" + task_judge: + # [key] Implement and select the evaluation function + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + model: + # [key] Set the model to be trained + path: YOUR_MODEL_PATH + rollout: + # [key] Implement and select the Agent + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + force_disable_toolcalls: True + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 +``` + +## 4. Results + +### 4.1 Training Curve + +![Training curve (small batch)](https://img.alicdn.com/imgextra/i2/O1CN01toRt2c1Nj8nKDqoTd_!!6000000001605-2-tps-1410-506.png) + +> **Visualization:** Training curves are generated by SwanLab. See [Visualization Tools](./visualization.md) for setup and usage. + +As training progresses, reward increases. This usually means the agent becomes more stable on **two things**: + +* **Following correct API protocols**: it learns to look up API documentation before calling, and uses valid API endpoints instead of hallucinating non-existent ones. +* **Completing multi-step workflows**: it can properly obtain access tokens and chain multiple API calls to accomplish complex tasks. + + +### 4.2 Case Study + +#### Before tuning: + +1. Frequently call non-existent APIs + +![Before tuning](https://img.alicdn.com/imgextra/i1/O1CN015FgjqI20Ip3AJybr0_!!6000000006827-2-tps-1259-683.png) + +The agent hallucinates API names without checking whether they exist, leading to repeated failures. + +2. Fail to follow the instructions to obtain an access token + +![Before tuning](https://img.alicdn.com/imgextra/i1/O1CN01bGZ1s01VyjCSrTJte_!!6000000002722-2-tps-1181-954.png) + +The agent attempts to call protected APIs without first obtaining the required access token, resulting in authentication errors. + +#### After tuning: + +1. Look up the API documentation first, and learn to use valid APIs + +![After tuning](https://img.alicdn.com/imgextra/i4/O1CN01VRIDy922PoKD1bETl_!!6000000007113-2-tps-1180-944.png) + +The agent now checks available APIs before making calls, avoiding hallucinated endpoints. + +2. Learn to obtain an access token correctly + +![After tuning](https://img.alicdn.com/imgextra/i2/O1CN01xiF9UU20h62dyrZ4x_!!6000000006880-2-tps-1182-793.png) + +The agent properly handles the authentication step before accessing protected APIs. + +> **Token-level Visualization:** These detailed logs are generated by Beast-Logger. See [Beast-Logger Usage](./beast_logger.md) for more details. diff --git a/docs/en/example_countdown.md b/docs/en/example_countdown.md new file mode 100644 index 00000000..ff8ec4e3 --- /dev/null +++ b/docs/en/example_countdown.md @@ -0,0 +1,203 @@ +# Countdown + +## 1. Overview + +Countdown is a math puzzle game. Given a list of numbers and a target number, the player needs to use the numbers and the four basic arithmetic operations (addition, subtraction, multiplication, and division) to form an expression that evaluates to the target number. Each number can be used only once, but parentheses can be used freely to change the order of operations. + +## 2. Quick Start + +### 2.1 Preparation +Download the `Jiayi-Pan/Countdown-Tasks-3to4` dataset and split it into training and test sets: + +```bash +python tutorial/example_countdown/prepare_data.py --target=Jiayi-Pan/Countdown-Tasks-3to4 --path=/the/path/to/store/dataset +``` + +The Countdown dataset contains the `target` and `nums` fields and requires custom data formatting logic. For example, when using the `huggingface_dat_repo` task reader method, you need to modify the `_load_dataset_split` method in `ajet/task_reader/hf_dataset_reader.py`: + +```python +task = Task( + main_query=json.dumps({'target': example["target"], 'nums': example["nums"]}), + init_messages=[], + task_id=str(idx), + env_type="no_env", + metadata=example, +) +``` + +### 2.2 Start Training + +Simply run the following command: + +```bash +# It is recommended to kill all ray, vllm, and env_service processes before starting ( python launcher.py --kill="python|ray|vllm" ) +ajet --conf tutorial/example_countdown/countdown.yaml --backbone='verl' +``` + +
+Quick Debugging (Optional) + +If you want to breakpoint-debug the workflow/judge locally: + +```bash +# (optional) recommended cleanup before debug +# ajet --kill="python|ray" + +clear && \ +ajet --conf tutorial/example_countdown/countdown.yaml --backbone='debug' --with-logview +``` + +When `--backbone=debug`, Ray is disabled. You can use a VSCode `.vscode/launch.json` like below: + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "./path/to/yaml.yaml" + ], + "env": {} + } + ] +} +``` +
+ +## 3. Understand + +In this section, we will cover the details of this tutorial. + + +### 3.1 Core Process +See details in `tutorial/example_countdown/countdown.py`. You can create new AgentScope Workflow code anywhere in the project. + +- **Define the AgentScope workflow** + +```python +self.agent = ReActAgent( + name="countdown_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + max_iters=2, +) +msg = Msg("user", query, role="user") +result = await self.agent.reply(msg) +``` + +In the AgentScope Workflow,you need to write the key information required by the judge into: + +```python +WorkflowOutput( + reward=None, + metadata={ + "final_answer": final_answer, + "target": target, + "nums": nums, + } +) +``` + +### 3.2 Configuration Details +Copy and modify key parameters in `tutorial/example_countdown/countdown.yaml`. The most relevant parts in the yaml file are marked with symbols. + +1. Read task (corresponds to configuration field `ajet.task_reader`) +2. Define Workflow (corresponds to configuration field `ajet.rollout.user_workflow`) + - Example: If agentscope workflow is defined in `ExampleCountdownLearn` class of `tutorial/example_countdown/countdown.py` + - Then set `ajet.rollout.user_workflow`=`tutorial.example_countdown.countdown->ExampleCountdownLearn` +3. Define scoring function (corresponds to configuration field `ajet.task_judge.judge_protocol`) + - Example: If agentscope workflow is defined in `CountdownAnswerAsJudge` class of `tutorial/example_countdown/countdown_answer_as_judge.py` + - Then set `ajet.task_judge.judge_protocol`=`tutorial.example_countdown.countdown_answer_as_judge->CountdownAnswerAsJudge` +4. Specify model (corresponds to configuration field `ajet.model.path`) + +```yaml +ajet: + task_reader: + type: huggingface_dat_repo # [key] `env_service` or `dataset_file` or `huggingface_dat_repo` or `data_generation` + rollout: + user_workflow: tutorial.example_countdown.countdown->ExampleCountdownLearn # [key] Write and select Agent + task_judge: + # [key] Write and select evaluation function + judge_protocol: tutorial.example_countdown.countdown_answer_as_judge->CountdownAnswerAsJudge + model: + # [key] Set the model to be trained + path: YOUR_MODEL_PATH +``` + +### 3.3 Code Map + +- `tutorial/example_countdown/countdown.py`: defines the AgentScope workflow (e.g., `ExampleCountdownLearn`). +- `tutorial/example_countdown/countdown.yaml`: wires together task reader, workflow, judge, and model. + +### 3.4 Reward/Evaluation Mechanism +A simple Judge is provided in `tutorial/example_countdown/countdown_answer_as_judge.py`. You can create new Judge code anywhere in the project. + +Judge input parameters include: + +``` +workflow_task: Task information(if reference answer is included, it can be retrieved from here) +workflow_output: Task information output (final_answer needs to be added manually) +``` + +Judge return values: + +- raw_reward +- is_success + +## 4. Results + +### 4.1 Training Curves/Metrics + +![Tracing curve](https://img.alicdn.com/imgextra/i4/O1CN01TtaeD91rnfBF736Zu_!!6000000005676-2-tps-1328-630.png) + +> **Visualization:** Training curves are generated by SwanLab. See [Visualization Tools](./visualization.md) for setup and usage. + +The upward trend in reward reflects improvement in two key areas: + +* **Format compliance**: the agent learns to emit well-formed `` JSON and wrap answers in `\boxed{...}`. +* **Accuracy improvement**: the agent achieves higher correctness rates in solving problems. + + +### 4.2 Case Study + +Agents are already able to deal with some simple problems at the beginning of the training process; however, small models inevitably produce many answers that fail to fully meet the requirements. In some problems, the agent does not strictly follow the instructions, such as invalid output format, reusing numbers, or missing expression outputs. + + +``` +bad case 1: Invalid output format. +<|im_start|>assistant + +{"name": "generate_response", "arguments": "Let's start by subtracting 28 from 76, which gives us 48. Now we need to get to 55 from 48. We can achieve this by adding 7 from the remaining numbers. However, 7 is not available. The next best number is 69. If we subtract 69 from 78, we get -11. If we subtract -11 from 48, we actually end up adding 11, which isn't one of our numbers. The correct approach would be to add 7 from the remaining numbers, but since 7 isn't available, we can use 69 and adjust our strategy. We can add 7 from the remaining numbers in another way. Let's try subtracting 69 from 78 to get -11, and then add 11 to 48. 11 can be obtained by subtracting 69 from 80, but since 80 isn't available, we'll use the next best approach. We'll subtract 69 from 78 to get -11, and then add 16 to 48 to reach 55. We can get 16 by subtracting 7 from 23, but since 23 isn't available, we need to find an alternative. The best way would be to use 78 - 69 = -11, and then 48 + 11 = 59 which is close but not 55. However, we can achieve 55 by using 78 - 69 + 7 = 55. So, the correct calculation is 78 - 69 + 7 = 55."} +<|im_end|> + +bad case 2: Reusing numbers in expressions. +<|im_start|>assistant + NavParams = {"response": "Let's solve this step by step. We have the numbers 72, 20, and 4, and we need to reach the target number 38 using arithmetic operations.\n\nFirst, let's try using subtraction and addition since 72 is quite large and can help us get closer to our target if we subtract something. We can subtract 20 from 72 to get 52, which is close but still far from 38. Now, we can subtract 4 from 52 to get 48, and then subtract 10 to reach 38. Since we don't have a 10 directly, we can use 20 - 4 = 16, and then subtract 16 from 48 to get 32, and finally add 6 (which is 4 + 2) to reach 38.\n\nLet's try a more efficient approach: 72 - 20 = 52, and 52 - 4 = 48, now we need to reduce it by 10, which is 20 - 14, and 14 can be obtained by 20 - 4 - 2. So, 48 - (20 - 4 - 2) = 48 - 14 = 34, then we need 4 more to reach 38, which is simply 4. Thus, the solution is (72 - 20) - 4 - (20 - 4 - 2) + 4 = 38."} + +{"name": "generate_response", "arguments": {"response": "The solution is \boxed{(72 - 20) - 4 - (20 - 4 - 2) + 4}. Let's break it down: \n1. 72 - 20 = 52\n2. 52 - 4 = 48\n3. 20 - 4 - 2 = 14\n4. 48 - 14 = 34\n5. 34 + 4 = 38\nTherefore, we reach the target number 38."}} +<|im_end|> + +bad case 3: Missing expression outputs. +<|im_start|>assistant + +{"name": "generate_response", "arguments": {"response": "In this case, one straightforward way to reach the target number 64 is by directly using the number 64 itself, since it matches the target exactly. No other calculations are needed."}} +boxing the answer as \boxed{{64}}.<|im_end|> +``` + +However, tuning resolves these issues, as shown in the example below: + +1. **Message 3**: decomposes the problem + emits a ``, and uses `\boxed{}` to output key values +2. **Message 4**: decomposes the problem further, and reuses numbers in expressions. + +![After tuning](https://img.alicdn.com/imgextra/i1/O1CN011HjAxo20DKIcPAPVr_!!6000000006815-2-tps-1658-506.png) +![After tuning](https://img.alicdn.com/imgextra/i4/O1CN01C3kUnV221zjPi30rd_!!6000000007061-2-tps-1650-730.png) + +> **Token-level Visualization:** These detailed logs are generated by Beast-Logger. See [Beast-Logger Usage](./beast_logger.md) for more details. diff --git a/docs/en/example_frozenlake.md b/docs/en/example_frozenlake.md new file mode 100644 index 00000000..972151b8 --- /dev/null +++ b/docs/en/example_frozenlake.md @@ -0,0 +1,168 @@ +# Frozen Lake + +## 1. Overview + +**Frozen Lake** is a classic reinforcement learning task from [Gymnasium](https://gymnasium.farama.org/environments/toy_text/frozen_lake/). + +In this environment, the agent is placed on a randomly generated frozen lake, which consists of safe ice ( _ ), dangerous holes (O), and a goal (G). The agent's position is marked as P. The goal is to navigate from the starting position P to the goal G while avoiding the holes. The agent can move up, down, left, or right, but due to the slippery nature of the ice, there is a probability of moving in an unintended direction. + +This example demonstrates how to create a trainable agent workflow to solve this navigation challenge. + +## 2. Quick Start + +### 2.1 Preparation + +Install the dependencies required for the Frozen Lake: + +```bash +pip install gymnasium[toy_text] +``` + +### 2.2 Start Training + +Use the provided configuration file to quickly start training: + +```bash +ajet --conf tutorial/example_frozenlake/frozenlake_easy.yaml --backbone='verl' +``` + +To try a harder setting: + +```bash +ajet --conf tutorial/example_frozenlake/frozenlake_hard.yaml --backbone='verl' +``` + +
+Quick Debugging (Optional) + +If you want to breakpoint-debug the workflow/judge locally: + +```bash +# (optional) recommended cleanup before debug +# ajet --kill="python|ray" + +clear && \ +ajet --conf tutorial/example_frozenlake/frozenlake_easy.yaml --backbone='debug' --with-logview +``` + +When `--backbone=debug`, Ray is disabled. You can use a VSCode `.vscode/launch.json` like below: + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "./path/to/yaml.yaml" + ], + "env": {} + } + ] +} +``` +
+ +## 3. Understand + +### 3.1 Core Process + +This example packages a multi-step environment interaction loop into a trainable `Workflow`: + +- The workflow resets the environment and renders the current grid as a text observation for the agent. +- The agent reads the observation and outputs one of `Up | Down | Left | Right`. +- The environment executes the action, returns the next observation and reward. +- The loop stops on success or when the max step limit is reached. + +### 3.2 Configuration Details + +The key fields in `tutorial/example_frozenlake/frozenlake_easy.yaml` / `frozenlake_hard.yaml` are: + +- `ajet.rollout.user_workflow`: entry point of the workflow class, set to `tutorial.example_frozenlake.frozenlake->FrozenLakeWorkflow`. +- `ajet.rollout.multi_turn.max_steps`: maximum steps per episode (also used by the agent). +- `frozen_lake.frozen_lake_size`: grid size (e.g. 4 for easy, 6 for hard). +- `frozen_lake.is_slippery`: whether the action may slip to unintended directions. + +### 3.3 Code Map + +The `FrozenLakeEnv` class in `tutorial/example_frozenlake/frozenlake.py` wraps the Gymnasium Frozen Lake environment, mainly exposing the `step` and `reset` methods. + +- The `step` method returns the next state (observation), reward, done flag, and additional info based on the agent's action. + - observation: The state of the lake after the agent moves, represented as a string, e.g.: + ``` + _ _ G + _ _ _ + P O O + ``` + - reward: The reward received after each move. The agent receives 1 for reaching the goal G, otherwise 0. + - done: Boolean value. True if the agent reaches the goal or falls into a hole, otherwise False. + - info: Additional information. + +- The `reset` method regenerates the lake environment based on user parameters. + +The `FrozenLakeAgent` class in `tutorial/example_frozenlake/frozenlake.py` implements the agent's decision logic, mainly through the `step` method, which takes the current environment observation as input and returns the chosen action. The core is a ReActAgent. + +```python +class FrozenLakeAgent: + + def __init__(self, model: ModelTuner, max_steps: int = 20): + self.agent = ReActAgent( + name="frozenlake_agent", + sys_prompt=SYSTEM_PROMPT, + model=model, + formatter=DashScopeChatFormatter(), + max_iters=2, + ) + # other initialization code + + async def step(self, current_observation: str) -> str: + # Step 1: Build user prompt based on current_observation + # Step 2: Call ReActAgent to get raw response + # Step 3: Parse response and return action +``` + +The `FrozenLakeWorkflow` class in `tutorial/example_frozenlake/frozenlake.py` integrates the environment and agent, mainly exposing the `execute` method. + +```python +class FrozenLakeWorkflow(Workflow): + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # init agent and env + # reset environment and get initial `observation_str` + rewards = [] + for _ in range(self.max_steps): + action = await self.agent.step(observation_str) + observation_str, reward, done, info = self.env.step(action) + rewards.append(reward) + if done: + break + return WorkflowOutput( + reward=sum(rewards), + ) +``` + +### 3.4 Reward + +- The per-episode reward is the sum of step rewards. +- In this FrozenLake setup, the agent gets `+1` when reaching the goal, otherwise `0`. +- The workflow also returns metadata such as `terminate_reason` (`success`, `agent_error`, `max_steps_reached`) and `step_count`. + +## 4. Results + +### 4.1 Training Curve + +![](https://img.alicdn.com/imgextra/i3/O1CN01ZfICUr1Rs4zXrPPXQ_!!6000000002166-2-tps-858-614.png) + +> **Visualization:** Training curves are generated by SwanLab. See [Visualization Tools](./visualization.md) for setup and usage. + +Since the reward is sparse (only `+1` when reaching the goal, otherwise `0`), the rising reward curve directly reflects an **increasing success rate**—the agent reaches the goal G more often. + +This improvement typically comes from two aspects: + +* **Better spatial reasoning**: the agent learns to parse the grid and identify the relative positions. +* **Safer path planning**: it avoids falling into holes and takes more reliable routes toward the goal. diff --git a/docs/en/example_learning_to_ask.md b/docs/en/example_learning_to_ask.md new file mode 100644 index 00000000..d5a17abe --- /dev/null +++ b/docs/en/example_learning_to_ask.md @@ -0,0 +1,224 @@ +# Learning to Ask + +Train an agent to **ask the next best question** (instead of answering directly). Rewards come from an **LLM-as-a-judge** that scores whether the question is helpful and relevant. + + + +### 1. Overview + + +In **Learning to Ask**, each training sample is a short **doctor–patient chat history**. The agent outputs **one next question** the doctor should ask next (optionally with multiple-choice answers), rather than giving diagnosis or treatment. + +![](https://img.alicdn.com/imgextra/i4/O1CN01m9WJCM1WJL1aJCSaS_!!6000000002767-2-tps-1024-559.png) +
Why "Learning to Ask" matters. Left: LLM gives a diagnosis with too little information. Right: LLM asks clear follow-up questions before concluding, which feels more reassuring.
+ + +This tutorial is organized in two steps: +1) **Run it**: start training with the default YAML config. +2) **Understand & customize**: dataset preprocessing, workflow (ExampleLearn2Ask), and reward (reward_fn + llm_reward). + +--- + +### 2. Quick Start + +#### 2.1 Preparation + +Download the [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv) dataset from HuggingFace and place files in: `data/realmedconv` + +Then preprocess it: + +```bash +export DASHSCOPE_API_KEY=your_api_key + +cd tutorial/example_learn2ask/data_preprocess +./run_process.sh data/realmedconv +``` + +After preprocessing, you should have: `train.jsonl` and`test.jsonl`。 + + +#### 2.2 Start Training + +```bash +ajet --conf tutorial/example_learn2ask/learn2ask.yaml --backbone='verl' +# or +ajet --conf tutorial/example_learn2ask/learn2ask.yaml --backbone='trinity' --with-ray +``` + +
+Quick Debugging (Optional) + +Run Ray locally without enabling it for faster iteration: + +``` +bash ajet --conf tutorial/example_learn2ask/learn2ask.yaml --backbone='debug' --with-logview + +``` + +If the results are incorrect, the quickest troubleshooting points include: whether the data path exists, whether an API key has been set if judge requires it, and whether the workflow classpath in `user_workflow` matches the location of your code. + +
+ + +### 3. Understand + +#### 3.1 What happens each step + +This tutorial trains a model to **ask the next best question** from a short doctor–patient chat history. Concretely, each training step takes one conversation context from `train.jsonl`, asks the agent to generate **exactly one follow-up question** (optionally with answer options), and then uses an LLM judge to score whether that question is useful and relevant. AgentJet uses this score as the reward signal to update the policy, so the model gradually learns to ask better questions instead of answering directly. + +#### 3.2 YAML Configuration + +The whole example is “wired” in the YAML and implemented in one file. In the YAML, `task_reader` provides the dataset split, `rollout.user_workflow` tells AgentJet which workflow to run for each sample, and `task_judge` provides the reward entry that wraps the LLM judge. The `model` section decides which pretrained backbone you start from. + +```yaml +ajet: + task_reader: + type: dataset_file + # train_path: data/realmedconv/train.jsonl + # test_path: data/realmedconv/test.jsonl + + rollout: + # For each sample: conversation context -> one next question + user_workflow: tutorial.example_learn2ask.learn2ask->ExampleLearn2Ask + + task_judge: + # Reward function used by the trainer (internally calls the LLM judge) + # judge_protocol: tutorial.example_learn2ask.learn2ask->reward_fn + + model: + # pretrained backbone to start from + # path: /path/to/your/model +``` + +#### 3.3 Code Map + +At the code level, everything is implemented in `tutorial/example_learn2ask/learn2ask.py`: + +* `ExampleLearn2Ask` defines the workflow: how the dialogue context is converted into the agent’s prompt/input, and what output format is expected (one follow-up question, optionally with choices). +* `reward_fn` defines how to convert the judge’s feedback into a scalar reward used for training. + +We provide two implmentations of the agent based on AgentScope and langchain: + +=== "AgentScope" + + ```python + # create the agent + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + toolkit=None, + memory=InMemoryMemory(), + max_iters=1, + ) + self.agent.set_console_output_enabled(False) + + # convert the messages to agent scope format and send to the agent + msg = [ + # Msg("system", system_prompt, role="system"), + *[Msg(name=x["role"], content=x["content"], role=x["role"]) for x in messages] + ] + result = await self.agent.reply(msg) + if isinstance(result.content, str): + response = result.content + elif isinstance(result.content, list): + response = result.content[0]["text"] # type: ignore + else: + raise NotImplementedError(f"do not know how to handle {type(result.content)}") + reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info) + return WorkflowOutput(reward=reward) + ``` + +=== "Langchain" + + ```python + # get the trainable llm + llm_info=tuner.as_oai_baseurl_apikey() + + # create the langchain agent + llm=ChatOpenAI( + base_url=llm_info.base_url, + api_key=lambda:llm_info.api_key, + ) + agent=create_agent( + model=llm, + system_prompt=system_prompt, + ) + + # build messages and send to the agent + msg=[ + {"role": x["role"], "content": x["content"]} for x in messages + ] + result = agent.invoke({ + "messages": msg, # type: ignore + }) + + response = result["messages"][-1].content + reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info) + return WorkflowOutput(reward=reward) + ``` + +#### 3.4 Reward + +`llm_reward` is the LLM-as-a-judge called inside `reward_fn` to score the model output. The evaluation follows these rules: + +- It **only evaluates the doctor’s last message** (doctor’s last message), and does not consider earlier doctor turns. +- It outputs two scores: **Format Score** + **Content Score** (scored separately, then combined by `reward_fn` into the training reward). + +**Format Score**: scored by the number of questions in the doctor’s last message +- 1.0: exactly **one question**, or correctly output `` when no question is needed +- 0.5: **two questions** +- 0.0: **three or more questions** + +**Content Score**: scored by whether the question targets the “missing information” in `Reference Information` (i.e., information the doctor does not yet know) +- 1.0: the question **directly asks about** an item in `Reference Information`, or correctly end the conversation when no more information is needed +- 0.1: the question is too generic (a general question that could apply to almost any symptom) +- 0.0: the question is **irrelevant** to the missing items in `Reference Information` +- Additionally: **ambiguous or uninformative questions are treated as low-quality** (e.g., unclear references), and will typically receive a score of 0 or close to 0 + +### 4. Results + +#### 4.1 Training Curve + +![curve](https://img.alicdn.com/imgextra/i2/O1CN01STeE8G1ricp6jySto_!!6000000005665-2-tps-866-614.png) + +> **Visualization:** Training curves are generated by SwanLab. See [Visualization Tools](./visualization.md) for setup and usage. + +This training curve reflects the trend of the reward over time. The reward includes: **Format Score** (whether the model asks exactly one question) and **Content Score** (whether it directly asks about the missing key information in `Reference Information`). + +Therefore, a rising curve indicates the model gradually learns to produce **one clear question** and to hit missing information more often, rather than asking generic questions or producing ambiguous/uninformative questions. + +#### 4.2 Case Study + +**Before training** + +```text +User: Bitten by a midge + +Agent: Has irritation subsided past initial acute swelling phase afterwards unnoticed scratched overly causing secondary bacterial dermatitis chance affecting skin rendering needed topical antibiotic application options among azelaic acid cream!? +``` + +The agent provided a slightly related but completely unreasonable answer and failed to form any effective question. + +**During training (~50%)** + +```text +User: Bitten by a midge + +Agent: Have these itchy bumps appeared soon after getting bitten?(A.) Occasionally immediately afterward ; B fkknore ×later throughout day )/) +``` + +The agent asked a more reasonable follow-up, attempting to gather symptoms and offering options. + +**After training** + +```text +User: Bitten by a midge + +Agent: Has itching or reddening appeared around this bite site recently without healing despite scratching frequently?(A,Yes indeed worse every evening B,No pain but small swollen spots persist day after incident/) +``` + +The question becomes more precise and informative, guiding the user to provide clinically relevant details. + +> To learn more about the task and results on larger models, refer to [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). diff --git a/docs/en/example_math_agent.md b/docs/en/example_math_agent.md new file mode 100644 index 00000000..e67c24b1 --- /dev/null +++ b/docs/en/example_math_agent.md @@ -0,0 +1,471 @@ +# Math Agent + +Train a **tool-using Math Agent** (ReAct + Python executor) to solve GSM8K-style math problems. Rewards come from a **judge** that checks final-answer correctness. + + +## Overview + +In Math Agent, each training sample is a math word problem (e.g., GSM8K). The agent learns to reason step by step (ReAct-style), call a Python tool when computation is needed, and produce a final answer that matches the reference. + + +This tutorial is organized into the following sections: + +- [**Run this tutorial**: Download the dataset and start training with the default YAML config.](#quick-start) +- [**Understand & customize**: Read the workflow and the judge/reward logic.](#explain) +- [**Training Curve**: Compare the training curlve.](#culve) + + + + + + + + + + +## Quick Start {#quick-start} + +### Prepare Dataset + +Download the `openai/gsm8k` dataset: + +```bash +python scripts/download_dataset.py --target=openai/gsm8k --path=/the/path/to/store/dataset +``` + +### Start Training + +```bash +# (optional) recommended cleanup before training +# ajet --kill="python|ray|vllm" + +ajet --conf tutorial/example_math_agent/math_agent.yaml --backbone='verl' +``` + +??? tip "Quick Debugging (Optional)" + If you want to breakpoint-debug the workflow/judge locally: + + ```bash + # (optional) recommended cleanup before debug + # ajet --kill="python|ray" + + clear && \ + ajet --conf tutorial/example_math_agent/math_agent.yaml --backbone='debug' --with-logview + ``` + + When `--backbone=debug`, Ray is disabled. You can use a VSCode launch config: + + ```json title=".vscode/launch.json" + { + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "./path/to/yaml.yaml" + ], + "env": {} + } + ] + } + ``` + + + + + + + + + + + + + +## Understanding the Training Pipeline {#explain} + +### Pipeline Abstraction + +
+
Training Step Flow
+ +
+
    +
  1. Load one problem + +Load a math problem from the dataset via `task_reader`.
  2. +
  3. Run the Workflow + +Build the prompt, let the ReActAgent call Python tools, and extract the final answer.
  4. +
  5. Return result as `WorkflowOutput` + +Return `WorkflowOutput(reward=None, metadata={"final_answer": final_answer})`. (reward=None because we want to compute reward outside the workflow)
  6. +
  7. Run the judge + +Compare `final_answer` with reference, compute `raw_reward` and `is_success`.
  8. +
+
+
+ +### YAML Configuration + +Most wiring happens in `tutorial/example_math_agent/math_agent.yaml`: + +=== "AgentScope" + + ```yaml title="math_agent.yaml" + ajet: + task_reader: + type: huggingface_dat_repo # also supports: dataset_file / env_service + + rollout: + user_workflow: tutorial.example_math_agent.math_agent->ExampleMathLearn + + task_judge: + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge + + model: + path: YOUR_MODEL_PATH + ``` + +=== "OpenAI" + + ```yaml title="math_agent.yaml" + ajet: + task_reader: + type: huggingface_dat_repo # also supports: dataset_file / env_service + + rollout: + user_workflow: tutorial.example_math_agent.math_agent_oai_sdk->ExampleMathLearn + + task_judge: + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge + + model: + path: YOUR_MODEL_PATH + ``` + +=== "Raw HTTP" + + ```yaml title="math_agent.yaml" + ajet: + task_reader: + type: huggingface_dat_repo # also supports: dataset_file / env_service + + rollout: + user_workflow: tutorial.example_math_agent.math_agent_raw_http->ExampleMathLearn + + task_judge: + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge + + model: + path: YOUR_MODEL_PATH + ``` + +=== "langchain" + + ```yaml title="math_agent.yaml" + ajet: + task_reader: + type: huggingface_dat_repo # also supports: dataset_file / env_service + + rollout: + user_workflow: tutorial.example_math_agent.math_agent_langchain->ExampleMathLearn + + task_judge: + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge + + model: + path: YOUR_MODEL_PATH + ``` + + +!!! warning "user_workflow assignment" + - As you have noticed, `user_workflow: tutorial.example_math_agent.math_agent_langchain->ExampleMathLearn` means, AgentJet will try to import `ExampleMathLearn` from `${WorkingDir}/tutorial/example_math_agent/math_agent_langchain.py`. (**Dot import**) + - If you prefer absolute path, or you workflow is not in python search path, you can also use the alternative way to import your workflow `user_workflow: /path/to/ajet/tutorial/example_math_agent/math_agent_langchain.py->ExampleMathLearn`. (**Path import**) + - Both **dot import** (dot-to-module) and **path import** (path-to-source-code) is good. But **dot import** is recommended as it is more pythonic. + + + + +| Field | Description | +|-------|-------------| +| `task_reader` | Where tasks come from | +| `user_workflow` | Which workflow runs per sample | +| `judge_protocol` | Which judge computes rewards | +| `model.path` | Pretrained model to fine-tune | + +### Code Walkthrough + +**Workflow:** `tutorial/example_math_agent/math_agent.py` + +=== "AgentScope" + + ```python title="Workflow Sketch" + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=model_tuner, # trainer-managed model wrapper + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + ) + + msg = Msg("user", init_messages[0]["content"], role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # IMPORTANT: provide final answer to the judge via WorkflowOutput metadata + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) + ``` + +=== "OpenAI" + + ```python title="Workflow Sketch" + client = tuner.as_raw_openai_sdk_client() + + # call 1: get response with tool call + messages = [ + { "role": "system", "content": self.system_prompt }, + { "role": "user", "content": query } + ] + reply_message: ChatCompletion = await client.chat.completions.create(messages=messages, tools=self.available_functions) + if (reply_message.choices[0].message.content): + messages.append({ + "role": "assistant", + "content": reply_message.choices[0].message.content + }) + + # If the model called a tool + if (reply_message.choices[0].message) and (reply_message.choices[0].message.tool_calls): + tool_calls: list[ChatCompletionMessageToolCall] = reply_message.choices[0].message.tool_calls + for tool_call in tool_calls: + if tool_call.function.name == "execute_python_code": + arguments = json.loads(tool_call.function.arguments) + + def sync_wrapper(): + import subprocess + import sys + process = subprocess.run( + [sys.executable, "-c", arguments["code"]], + timeout=arguments.get("timeout", 300), + capture_output=True, + text=True + ) + return process.stdout + + result = await asyncio.to_thread(sync_wrapper) + tool_result_message = { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_call.function.name, + "content": json.dumps({ + "return_code": str(result), + }) + } + messages.append(tool_result_message) + + # Step 3: Make a follow-up API call with the tool result + final_response: ChatCompletion = await client.chat.completions.create( + messages=messages, + ) + final_stage_response = final_response.choices[0].message.content + else: + final_stage_response = reply_message.choices[0].message.content + + + return WorkflowOutput(reward=None, metadata={"final_answer": final_stage_response}) + ``` + + +=== "Raw HTTP" + + ```python title="raw http" + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + # take out query + query = workflow_task.task.main_query + + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": query + } + ] + + # use raw http requests (non-streaming) to get response + response = requests.post( + f"{base_url}/chat/completions", + json={ + "model": "fill_whatever_model", # Of course, this `model` field will be ignored. + "messages": messages, + }, + headers={ + "Authorization": f"Bearer {api_key}" + } + ) + final_answer = response.json()['choices'][0]['message']['content'] + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) + ``` + + +=== "Langchain" + + ```python title="langchain" + # tuner to api key + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + from langchain_openai import ChatOpenAI + llm=ChatOpenAI( + base_url=base_url, + api_key=lambda:api_key, + ) + agent=create_agent( + model=llm, + system_prompt=self.system_prompt, + ) + + # take out query + query = workflow_task.task.main_query + + response = agent.invoke({ + "messages": [ + { + "role": "user", + "content": query + } + ], + }) + + final_answer = response['messages'][-1].content + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) + ``` + +!!! warning "Important" + - User should put all elements necessary for reward computation in `WorkflowOutput.metadata`, + so the judge can use them. + - In this specific case, `final_answer` is that key element. + + + +### Reward Computation + +The judge receives: + +| Object | Contains | +|--------|----------| +| `workflow_task` | Task info; reference answer from `metadata` | +| `workflow_output` | Workflow result; final answer from `metadata["final_answer"]` | + +!!! tip "Extending the Judge" + If you observe issues like "almost solved but messed up tool-call formatting", you can extend the judge to add: + + - Format penalty (invalid ``) + - Behavior penalty (tool called but no `print`) + - Keep answer correctness as the primary signal + + +### YAML Configuration + +Most wiring happens in `tutorial/example_math_agent/math_agent.yaml`: + +```yaml title="math_agent.yaml" +ajet: + task_reader: + type: huggingface_dat_repo # also supports: dataset_file / env_service + + rollout: + user_workflow: tutorial.example_math_agent.math_agent->ExampleMathLearn + + task_judge: + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge + + model: + path: YOUR_MODEL_PATH +``` + +| Field | Description | +|-------|-------------| +| `task_reader` | Where tasks come from | +| `user_workflow` | Which workflow runs per sample | +| `judge_protocol` | Which judge computes rewards | +| `model.path` | Pretrained model to fine-tune | + + + + + +## Results {#culve} + +### Training Curve + +![Training curve](https://img.alicdn.com/imgextra/i4/O1CN01gzwgLq1fkCnauydEu_!!6000000004044-2-tps-1422-550.png) + +!!! info "Visualization" + Training curves are generated by SwanLab. See [Visualization Tools](./visualization.md) for setup. + +**Interpretation:** As training progresses, reward increases. This usually means the agent becomes more stable at: + +- **Using tools when needed**: Correctly emitting `` and calling `execute_python_code` +- **Producing reliable answers**: Using tool output to produce final answers aligned with reference + +### Case Study: Tool Discipline Improvement + +Before training, the agent may solve many problems but often fails at **tool-call discipline**: + +=== "Bad Cases" + + ```text + # bad case 1: forgot to print the result in python code + + {"name": "execute_python_code", "arguments": {"code": "... height_difference"}} + + + # bad case 2: too impatient — outputs final answer without waiting for tool result + {"name": "execute_python_code", ...} + {"name": "generate_response", "arguments": {"response": "... \\boxed{48} ..."}} + ``` + + These failures are not because the model "can't do math", but because it **does not close the loop** by incorporating the tool execution result. + +=== "Good Case (After Tuning)" + + After tuning, the agent follows a clean 3-stage pattern: + + 1. **Message 3 (assistant)**: Decomposes problem + emits `` with `print(...)` + 2. **Message 4 (tool_response)**: Tool returns execution results + 3. **Message 5 (assistant)**: Reads `stdout` and produces final answer + + ![Good case](https://img.alicdn.com/imgextra/i4/O1CN01v1gGQZ1ftMiil5Cxg_!!6000000004064-2-tps-1367-684.png) + +!!! note "Token-level Visualization" + The colored blocks show token-level sequence visualization from [Beast-Logger](./beast_logger.md): + + - **Yellow tokens**: Excluded from loss computation + - **Blue tokens**: Participate in loss computation (light to dark = high to low logprob) + +--- + +## Next Steps + + diff --git a/docs/en/example_tracing_feedback_loop.md b/docs/en/example_tracing_feedback_loop.md new file mode 100644 index 00000000..4a348402 --- /dev/null +++ b/docs/en/example_tracing_feedback_loop.md @@ -0,0 +1,111 @@ +# Tracing-Feedback Loop + +AgentJet allows you to recycle the chat logs generated during an Agent's execution and continuously improve the Agent through iterative training, which we call **Tracing-Feedback Training**. It provides features + ++ Loading tracing log from agentscope studio database ++ Converting log into formatted data ++ Filtering high-quality samples with custom rubrics/filters ++ Packing samples into datasets for iterative training + + +In the next section, we will demonstrate how to improve an Agent using Tracing-Feedback Training. + +> **AgentScope & Studio Version Compatibility** +> +> It is recommended to use matched versions: +> +> + AgentScope (v1.0.7) +> + Studio (23eb7c0b1185486d1baca36aea0ce8b85ea9de48) +> + +## Setup + +To use tracing logs for training, you are expected to already have an agent built with **agentscope** running in **agentscope-studio** for some time (usually in production), which means you have + +1. Written your agent with [agentscope](https://github.com/agentscope-ai/agentscope). +2. Enabled tracing module following [the doc](https://doc.agentscope.io/tutorial/task_tracing.html). +3. Deployed your agent and collected the database. + +By default, agentscope-studio will store the tracing logs in +`~/AgentScope-Studio/database.sqlite`, containing all recorded dialogues between the user and the agent. + + + +We have prepared a demo agent in `tutorials/example_feedback_tracing/agent_deployed.py`. You can simulate the tracing log with it and get the database file. + +## Start Tracing-Feedback Training + +Once we have the log (`database.sqlite`), we can train a new Agent with Tracing-Feedback Training. + +1. Set the `ajet.task_reader.type` parameter to `tracing` in the configuration file to enable tracing-feedback mode. +2. Configure the `ajet.task_reader.feedback_tracing` section with the database path and filtering options. +3. Configure other training parameters and Rewards as you would in a normal training workflow. + +```yaml +ajet: + # ... + task_reader: + # use tracing log as tasks + type: tracing + feedback_tracing: + # path to the database + base_url: ./tutorial/example_feedback_tracing/database.sqlite + # path where the module will write cache + train_output_path: ./tutorial/example_feedback_tracing/tasks.jsonl + # the model used in filters + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 2048 + # filters + filters: + # the default filter llm_evaluate + - type: llm_evaluate + enabled: true + params: + # define your rubrics to drop any bad-quality tasks + custom_rubrics: | + 1. Check the answer and drop the task if it does not answer or answer is wrong. + 2. Consider a response is invalid if it does not wrap the final answer in \boxed{}. + # LLM temperature + temperature: 0.5 + # print debug log + print_reason: false + max_thread: 16 +``` + +When everything is ready, start the training with `launcher.py`. + +```bash +ajet --conf tutorial/example_feedback_tracing/example_feedback_tracing.yaml --backbone='verl' +# or +ajet --conf tutorial/example_feedback_tracing/example_feedback_tracing.yaml --backbone='trinity' --with-ray +``` + +After training, we can now deploy the new Agent into production and collect new logs. This workflow enables continuous improvement through iterative tracing-feedback training. + +## Customize + +### Filter + +The module provides Filter to select high-quality samples from logs for training. Users are allowed to customize the specific rubrics of their own tasks. + +To write rubrics, edit the configuration file: + +```yaml +ajet: + # ... + task_reader: + # ... + feedback_tracing: + # ... + filters: + - type: llm_evaluate + enabled: true # enable the filter + params: + # define your rubrics + custom_rubrics: | + 1. Check the answer and drop the task if it does not answer or answer is wrong. + 2. Consider a response is invalid if it does not wrap the final answer in \boxed{}. + temperature: 0.5 + print_reason: false + max_thread: 16 +``` diff --git a/docs/en/example_werewolves.md b/docs/en/example_werewolves.md new file mode 100644 index 00000000..8ff865ea --- /dev/null +++ b/docs/en/example_werewolves.md @@ -0,0 +1,165 @@ +# Werewolves + +This tutorial demonstrates how to train **multiple agents** to play the Werewolves game. + +## 1. Overview + +The Werewolves role-playing game is a typical POMDP (Partially Observable Markov Decision Process) problem. We can train agents in this cooperative multi-agent problem using shared-parameter methods. + +Terms explained: + +- **Partially Observable**: Agents are only able to receive **local information**. One agent cannot obtain others' perception, even if they are teammates. +- **Markov Decision Process**: Making decisions according to current situations. +- **Shared-parameter**: Using one model as policy for multiple agents. But notice agents **share** policy (model parameters) but **do not share** perception (model input). +- **Cooperative multi-agent problem**: Agents have aligned interests (reward). +- **Environment**: We use static **`Qwen3-235B-A22B`** as the brain of opponents. We use **`Qwen2-7B`** as the brain of trainable agents (`trainable_targets`). + +![image](https://img.alicdn.com/imgextra/i2/O1CN012JgVZC2ABczBhAzJs_!!6000000008165-0-tps-2048-2048.jpg) + +This page shows how to use the Werewolves social deduction game as a multi-agent environment to prepare data and environment, write an AgentScope Workflow, configure the reward module (Judge), and complete the full process from local debugging to formal training. + +Scenario Overview + +- Scenario: Classic Werewolves game, including roles such as werewolf, villager, seer, witch, and hunter. +- Goal: Train a specific role (in this example, the `werewolf`) to achieve a higher win rate in games. + +## 2. Quick Start + +Start training with the following command: +``` +# ( ajet --kill="python|ray|vllm" ) +ajet --conf tutorial/example_werewolves/werewolves.yaml --backbone='verl' +``` + +
+Quick Debugging (Optional) + +If you want to breakpoint-debug the workflow/judge locally: + +```bash +# (optional) recommended cleanup before debug +# ajet --kill="python|ray" + +clear && \ +ajet --conf tutorial/example_werewolves/math_agent.yaml --backbone='debug' --with-logview +``` + +When `--backbone=debug`, Ray is disabled. You can use a VSCode `.vscode/launch.json` like below: + +```json +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "./path/to/yaml.yaml" + ], + "env": {} + } + ] +} +``` +
+ +## 3. Understand + +### 3.1 Core Process + +At a high level, each training iteration follows this flow: + +- The task reader generates a new game setup (players, role assignments, initial state). +- The rollout runs the AgentScope workflow to simulate a full game. +- Agents in `trainable_targets` act by using the trainable model (via `tuner.as_agentscope_model(...)`), while opponents use the fixed model. +- The environment produces rewards / outcomes for the episode. +- Trajectories are collected and passed to the backbone trainer (`verl` or `trinity`) to update the trainable model. + +### 3.2 Configuration Details + +This section corresponds to `tutorial/example_werewolves/werewolves.yaml`. The key configuration items are as follows: +```yaml +ajet: + task_reader: + # random seed to shuffle players + type: random_dummy + task_judge: + # Implement and select the evaluation function + # (in this example you can first set it to null and rely purely on the rollout's internal reward) + judge_protocol: null + model: + # Set the model to be trained + path: YOUR_MODEL_PATH + rollout: + # Select the AgentScope Workflow entry + user_workflow: tutorial.example_werewolves.start->ExampleWerewolves +``` + +### 3.3 Code Map + +- `tutorial/example_werewolves/werewolves.yaml`: connects the task reader, judge, model, and workflow entry. +- `tutorial/example_werewolves/start.py`: the AgentScope workflow implementation (`ExampleWerewolves`). +- `tutorial/example_werewolves/game.py`: the Werewolves game logic implementation. +- `tutorial/example_werewolves/prompt.py`: prompt templates related to the game. +- `tutorial/example_werewolves/structured_model.py`: defines structured output formats for different roles. +- `tutorial/example_werewolves/utils.py`: game state management and helper functions. + +### 3.4 Reward + +When `judge_protocol: null`, training relies on the reward (or win/loss outcome) produced inside the rollout / environment. In this example, the reward is produced in the workflow in `tutorial/example_werewolves/start.py`. + +In `ExampleWerewolves.execute()`, the workflow first runs a full game by calling `werewolves_game(players, roles)`, and obtains `good_guy_win` (whether the good-guy side wins). + +Then it uses a **turn-level sparse win/loss reward**: + +- If `good_guy_win == True` and the training target is not `werewolf` (i.e., you are training a good-guy role), then `raw_reward = 1` and `is_success = True`. +- If `good_guy_win == False` and the training target is `werewolf` (i.e., you are training a werewolf-side role), then `raw_reward = 1` and `is_success = True`. +- Otherwise, the training side did not win: `raw_reward = 0` and `is_success = False`. + +Exception / invalid-behavior penalty: + +- If an exception is thrown during the game (e.g., the game cannot proceed), all trainable targets are penalized uniformly: `raw_reward = -0.1` and `is_success = False`. + +If you need a more fine-grained evaluation (e.g., giving partial credit for key intermediate decisions instead of only win/loss), implement a custom Judge and enable it via `ajet.task_judge.judge_protocol`. + +## 4. Results + +### 4.1 Training Curves + +`Qwen2-7B` is able to reach about 60% win rate in about 20 steps. + +![image](https://img.alicdn.com/imgextra/i3/O1CN01ldZYDT1ZqGLHuwsrS_!!6000000003245-2-tps-2000-839.png) + +> **Visualization:** Training curves are generated by SwanLab. See [Visualization Tools](./visualization.md) for setup and usage. + +As training progresses, win rate increases. This usually means the agent becomes more stable on **two things**: +- **Role-playing consistency**: the agent learns to maintain its werewolf cover under pressure, avoiding self-exposure even when voted out. +- **Social deception skills**: it develops strategies to mislead opponents, sow suspicion among villagers, and implicitly coordinate with teammates. + +### 4.2 Case Study + +#### Behavior Shifts + +Significant role-playing improvement is observed during the experiment. + +1. For example, when voted out, the original model tends to reveal its identity as `werewolf`, but after fine-tuning, the agent will try to cheat its opponents and protect teammates. For example: + +![](https://img.alicdn.com/imgextra/i1/O1CN01v8VqLB1aYEMfzyTHr_!!6000000003341-2-tps-2104-1016.png) + +> **Token-level Visualization:** These detailed logs are generated by Beast-Logger. See [Beast-Logger Usage](./beast_logger.md) for more details. + +2. The agent develops multiple strategies for winning. For example: +- **Misleading opponents**: "Let's keep an eye on the seer and the witch. They could be werewolves trying to hide". +- **Appealing to reason**: "We need to be wary of fake seers and watch for inconsistencies in stories, Player-Y as hunter should act carefully". + +3. Sometimes agents can take advantage of suspicion between non-werewolf players to eliminate opponents. + +![](https://img.alicdn.com/imgextra/i2/O1CN01Sx7wkU23pHyPXyqPH_!!6000000007304-2-tps-968-575.png) + +#### Expanding Qwen2-7B to Qwen2-14B + +![](https://img.alicdn.com/imgextra/i1/O1CN01TLZcQF1FJ1HPbpLfj_!!6000000000465-2-tps-1842-1008.png) diff --git a/docs/en/hardware_related_solution.md b/docs/en/hardware_related_solution.md new file mode 100644 index 00000000..9743d384 --- /dev/null +++ b/docs/en/hardware_related_solution.md @@ -0,0 +1,20 @@ +This document records a list of **Hardware Related** issues for future reference. + +## 1. ncclUnhandledCudaError: Call to CUDA function failed. + +- Problem: + + ```python + File "/root/AgentJet/.venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 322, in _sync_params_and_buffers + dist._broadcast_coalesced( + torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/NCCLUtils.cpp:77, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.26.2 + ncclUnhandledCudaError: Call to CUDA function failed. + Last error: + Cuda failure 1 'invalid argument' + ``` + +- Solution: + + ```bash + export NCCL_NVLS_ENABLE=0 + ``` diff --git a/docs/en/installation.md b/docs/en/installation.md new file mode 100644 index 00000000..4baabcaf --- /dev/null +++ b/docs/en/installation.md @@ -0,0 +1,137 @@ +# Installation Guide + +This document provides a step-by-step guide to installing AgentJet. + +!!! tip "Latest Version Recommended:" + + AgentJet is under active development and iteration. We recommend installing from source to get the latest features and bug fixes. + + +## Prerequisites + +| Requirement | Detail | +|-------------|---------| +| **Python** | 3.10 | +| Package Management | `uv` or `conda` | + + +## Install from Source + +### Step 1: Clone the Repository + +Clone the AgentJet repository from GitHub and navigate into the project directory: + +```bash +git clone https://github.com/modelscope/AgentJet.git +cd AgentJet +``` + +### Step 2: Install Dependencies + +AgentJet supports multiple backbones, you can choose any of them depending on your requirements, or choose all of them to compare the performance. Currently we have `verl` and `trinity`. + +!!! info "Package Manager" + We recommend using `uv` to manage your Python environment as it is incredibly fast. See also [`uv` installation document](https://docs.astral.sh/uv/getting-started/installation/). + + And of course, if you prefer `conda`, you can also install via conda and pip (simply change `uv pip` to `pip`). + +=== "VERL (uv)" + + ```bash + # Install with `verl` training backbone: + + uv venv --python=3.10 + source .venv/bin/activate + uv pip install -e .[verl] + + #`flash-attn` must be installed after other dependencies + uv pip install --verbose flash-attn --no-deps --no-build-isolation --no-cache + ``` + + !!! warning "flash-attn Installation" + - `flash-attn` must be installed **after** other dependencies. + - If you find your machine spend a long time installing flash-attn, ensure a healthy connection to GitHub. + +=== "VERL (conda)" + + ```bash + # Install with `verl` training backbone: + + conda create -n ajet-verl python=3.10 + conda activate ajet-verl + pip install -e .[verl] + + #`flash-attn` must be installed after other dependencies + pip install --verbose flash-attn --no-deps --no-build-isolation --no-cache + ``` + + + !!! warning "flash-attn Installation" + - `flash-attn` must be installed **after** other dependencies. + - If you find your machine spend a long time installing flash-attn, ensure a healthy connection to GitHub. + + +=== "VERL (aliyun)" + + + ```bash + # Install with `verl` training backbone: + + uv venv --python=3.10 + source .venv/bin/activate + uv pip install -i https://mirrors.aliyun.com/pypi/simple/ -e .[verl] + + #`flash-attn` must be installed after other dependencies + uv pip install -i https://mirrors.aliyun.com/pypi/simple/ --verbose flash-attn --no-deps --no-build-isolation --no-cache + ``` + + !!! warning "flash-attn Installation" + - `flash-attn` must be installed **after** other dependencies. + - Ensure a healthy connection to GitHub to install pre-compiled wheels. + - If you find your machine spend a long time installing flash-attn, ensure a healthy connection to GitHub. + - To build faster, export `MAX_JOBS=${N_CPU}`. + + +=== "Trinity" + + ```bash + # Install with `trinity` training backbone for fully asynchronous RFT: + + uv venv --python=3.10 + source .venv/bin/activate + uv pip install -e .[trinity] + uv pip install --verbose flash-attn --no-deps --no-build-isolation --no-cache + ``` + + +=== "Trinity (aliyun)" + + ```bash + # Install with `trinity` training backbone for fully asynchronous RFT: + + uv venv --python=3.10 + source .venv/bin/activate + uv pip install -i https://mirrors.aliyun.com/pypi/simple/ -e .[trinity] + uv pip install -i https://mirrors.aliyun.com/pypi/simple/ --verbose flash-attn --no-deps --no-build-isolation --no-cache + ``` + + +| Backbone | VeRL | Trinity-RFT | +| -------- |-------- | ------------- | +| Core design | Share-GPU actor-rollout engine (colocate) | Async actor-rollout engine | +| Speed | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | +| Scalability | ⭐⭐ | ⭐⭐⭐⭐ | +| Minimum Required GPU Resource | 1 | 2 | +| Training Stability | ⭐⭐⭐⭐ | ⭐⭐⭐ | +| vLLM Version | 0.10.0 | 0.10.0 | + + + + + +## Next Steps + + diff --git a/docs/en/intro.md b/docs/en/intro.md new file mode 100644 index 00000000..c4f93f11 --- /dev/null +++ b/docs/en/intro.md @@ -0,0 +1,114 @@ +# Introduction + +**AgentJet (AJet)** is a cutting-edge, user-friendly agent tuning framework designed to optimize LLM models and agent workflows. + +Simply provide your workflow (built from AgentScope, OpenAI SDK, Langchain, raw HTTP requests, or hybrid of all of them), training data, and reward function, and we will be ready to enhance your agents to their optimal performance! + + +## Features + +AgentJet aims to build a state-of-the-art agent tuning platform for both developers and researchers + +- **Easy and Friendly**. AgentJet helps you tune models behind your agent workflows easily, optimizing your agents for top performance with minimal effort. +- **Rich Tutorial Library**. AgentJet provides a rich library of [examples](https://github.com/modelscope/AgentJet/tree/main/tutorial) as tutorials. +- **Efficient and Scalable**. AgentJet uses [verl] as the default backbone (`--backbone=verl`). However, we also support [trinity](https://github.com/modelscope/Trinity-RFT/) as alternative backbone, accelerating your tuning process via fully asynchronous RFT. +- **Flexible and Fast**. AgentJet supports [multi-agent workflows](workflow.md) and adopts a context merging technique, accelerating training by 1.5x to 10x when the workflow involves multi-turn (or multi-agent) conversations. +- **Reliability and Reproducibility**. Our team keeps track of framework performance across multiple [tasks + major-git-version + training-backbones](https://benchmark.agent-matrix.com/) (under construction, still gathering data, coming soon). + +For advanced researchers, AgentJet also provides high-resolution logging and debugging solutions: + + +- **High-Resolution Logging**: AgentJet allows users to save and inspect token-level rollout details, recording token IDs, token loss masks, and even token logprobs to facilitate workflow development and agent diagnostics. +- **Fast Debugging**: AgentJet also provides the `--backbone=debug` option for the best debugging experience, shortening your wait period from minutes to seconds after code changes and enabling breakpoint debugging in IDEs. + + + + +## Quick Start + +### Installation + +We recommend using `uv` for dependency management. + +=== "Step 1: Clone Repository" + + ```bash + git clone https://github.com/modelscope/AgentJet.git + cd AgentJet + ``` + +=== "Step 2: Setup Environment" + + ```bash + uv venv --python=3.10.16 && source .venv/bin/activate + uv pip install -e .[trinity] + # Note: flash-attn must be installed after other dependencies + uv pip install flash_attn==2.8.3 --no-build-isolation --no-cache-dir + ``` + +- Train the first agent + +You can start training your first agent with a single command using a pre-configured YAML file: + +```bash +ajet --conf tutorial/example_math_agent/math_agent.yaml +``` + +!!! example "Learn More" + See the [Math Agent](./example_math_agent.md) example for detailed explanation. + + +## Example Library {#example-library} + +Explore our rich library of examples to kickstart your journey: + + + +--- + +## Core Concepts + +AgentJet makes agent fine-tuning straightforward by separating the developer interface from the internal execution logic. + +
+AgentJet Architecture +
+ +### 1. The User-Centric Interface + +To optimize an agent, you provide three core inputs: + + + +### 2. Internal System Architecture + +The internal system orchestrates several specialized modules to handle the complexities of RL training and agent interactions. + +| Module | Description | +|--------|-------------| +| **Launcher** | Manages background service processes (Ray, vLLM) and routes the backbone | +| **Task Reader** | Handles data ingestion, augmentation, and filtering | +| **Task Rollout** | Bridges LLM engines and manages the Gym environment lifecycle | +| **Task Runner** | Executes the AgentScope workflow and calculates rewards | +| **Model Tuner** | Forwards inference requests from the workflow to the LLM engine | +| **Context Tracker** | Monitors LLM calls and automatically merges shared-history timelines (**1.5x-10x** efficiency boost) | + +--- + +## Next Steps + + diff --git a/docs/en/platform_comparison.md b/docs/en/platform_comparison.md new file mode 100644 index 00000000..58a12f04 --- /dev/null +++ b/docs/en/platform_comparison.md @@ -0,0 +1,32 @@ + +# Compare AgentJet with Other Platforms with Agentic RL + + +- Multi OSS Training Backbone: Support switching between multiple open-source training backbones quickly. +- Multi OSS Infer Backbone: Support both vLLM and SGLang. +- Low Code Change: Do not require too many edits to convert a user‑defined (multi) agent workflow into trainable workflows. +- Without-GPU (Cloud-Computing): Rollout and power RL training in a laptop without GPU, using Tinker (AgentLightning) or without Tinker (AgentJet-TinkerScript, comming soon) +- Timeline Optimization: Automatically merge shared-history context generated by the same agents to promote training speed. +- Open Bench Platform: Trace baseline environment's performance across git history in different training backbones. +- Multi-Agent Optimization: Deal with sophisticated multi-agent interaction efficiently, automatically clustering and merging samples generated by the same agents. +- High-res Rollout Logging: Integrated with token level rollout trajectory, highlighting token logprob and loss mask for deep‑level research. +- Agentic Framework Compatible: Easy to convert AgentScope and Langchain workflows into trainable workflows. + + +| Feature | AgentJet | AgentJet-TinkerScript | AgentLightning | rLLM | VeRL 0.7.0 | +|--------- |----------|----------------------|---------------- |------|------------ | +| Multi OSS Training Backbone | ✅ | ❌ | ❌ | ❌ | ➖ | +| Multi OSS Infer Backbone | ✅ | ✅ | ✅ | ✅ | ✅ | +| Low Code Change | ✅ | ✅ | ✅ | ✅ | ➖ | +| Without-GPU (Cloud-Computing) | ❌ | ✅ | ✅ | ❌ | ➖ | +| Timeline Optimization | ✅ | ✅ | ❌ | ❌ | ➖ | +| Open Bench Platform | ✅ | ✅ | ❌ | ❌ | ✅ | +| Multiagent Optimization | ✅ | ✅ | ❌ | ❌ | ✅ | +| High-res Rollout Logging | ✅ | ✅ | ❌ | ❌ | ❌ | +| AgentScope Compatible | ✅ | ✅ | ✅ | ✅ | ➖ | +| Langchain Compatible | ✅ | ✅ | ✅ | ✅ | ➖ | + + +!!! Note "" + - ✅ = "supported"; ❌ = "not supported"; ➖ = "not applicable" or "not investigated" + - All projects are quickly evolving. We expect features not supported today will catch up sooner or later. diff --git a/docs/en/quickstart.md b/docs/en/quickstart.md new file mode 100644 index 00000000..fde601d2 --- /dev/null +++ b/docs/en/quickstart.md @@ -0,0 +1,73 @@ +# Quick Start + + +## 1. Testing Pre-define Demo + +AgentJet provides a complete feature set for tuning agents. You can try starting training an agent right away by running a demo: + +```bash +ajet --conf tutorial/example_math_agent/math_agent.yaml +``` + + +## 2. Minimum Example + +Let's begin with the simplest example: a math agent with a tool call. + +```python title="train_math_agent.py" +from ajet import AgentJetJob + +# refer to `https://modelscope.github.io/AgentJet/en/tune_your_first_agent/` on how to write your own workflow +from tutorial.example_math_agent.math_agent_simplify import MathToolWorkflow + +model_path = "YOUR_MODEL_PATH" +job = AgentJetJob(n_gpu=8, algorithm='grpo', model=model_path) +job.set_workflow(MathToolWorkflow) +job.set_data(type="hf", dataset_path='openai/gsm8k') + +# [Optional] Save yaml file for manual adjustment +# job.dump_job_as_yaml('saved_experiments/math.yaml') + +# [Optional] Load yaml file from manual adjustment +# job.load_job_from_yaml('saved_experiments/math.yaml') + +# Start training +tuned_model = job.tune() +``` + +!!! tip "CLI Alternative" + The code above is equivalent to running in terminal: + ```bash + ajet --conf ./saved_experiments/math.yaml + ``` + + +## 3. Compare with Community Training Curves + + + + + + +## 4. Explore Example Gallery + +Explore our rich library of examples to kickstart your journey: + + + + +## 5. Next Steps + + diff --git a/docs/en/setup_ubuntu.md b/docs/en/setup_ubuntu.md new file mode 100644 index 00000000..b9b7f7b4 --- /dev/null +++ b/docs/en/setup_ubuntu.md @@ -0,0 +1,31 @@ +# 1. install docker +``` +sudo apt update +sudo apt install docker docker.io curl +``` + +# 2. intsall nvidia-runtime + +``` sh +# https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#installing-on-ubuntu-and-debian + +# step 1 +curl https://get.docker.com | sh \ + && sudo systemctl --now enable docker + +# step 2 +distribution=$(. /etc/os-release;echo $ID$VERSION_ID) \ + && curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ + && curl -s -L https://nvidia.github.io/libnvidia-container/$distribution/libnvidia-container.list | \ + sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ + sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list + +# step 3:Install the nvidia-docker2 package (and dependencies) after updating the package listing: + +sudo apt-get update +sudo apt-get install -y nvidia-docker2 + +# Restart the Docker daemon to complete the installation after setting the default runtime: + +sudo systemctl restart docker +``` diff --git a/docs/en/support_agentscope.md b/docs/en/support_agentscope.md new file mode 100644 index 00000000..e551e4d9 --- /dev/null +++ b/docs/en/support_agentscope.md @@ -0,0 +1,225 @@ +# Supported Agent Frameworks: AgentScope + +This article introduce the way to convert different types of ways to convert your existing workflows into AgentJet workflows. + + +## AgentScope + +1. use `tuner.as_agentscope_model()` to override ReActAgent's model argument + +2. use `tuner.as_oai_baseurl_apikey()` to override OpenAIChatModel's baseurl + apikey argument + +### Explain with examples + +=== "Before Convertion" + + ```python + model = DashScopeChatModel(model_name="qwen-max", stream=False) # ✈️ change here + agent_instance = ReActAgent( + name=f"Friday", + sys_prompt="You are a helpful assistant", + model=model, + formatter=DashScopeChatFormatter(), + ) + ``` + +=== "After Convertion (`as_agentscope_model()`)" + + ```python + model = tuner.as_agentscope_model() # ✈️ change here + agent_instance = ReActAgent( + name=f"Friday", + sys_prompt="You are a helpful assistant", + model=model, + formatter=DashScopeChatFormatter(), + ) + ``` + +=== "After Convertion (`as_oai_baseurl_apikey()`)" + + ```python + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key # the api key contain information, do not discard it + model = OpenAIChatModel( + model_name="whatever", + client_args={"base_url": base_url}, + api_key=api_key, + stream=False, + ) + self.agent = ReActAgent( + name="math_react_agent", sys_prompt=system_prompt, + model=model, # ✨✨ compared with a normal agentscope agent, here is the difference! + formatter=OpenAIChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), max_iters=2, + ) + ``` + + +!!! warning "" + - when you are using the `tuner.as_oai_baseurl_apikey()` api, you must enable the following feature in the yaml configuration. + + ```yaml + + ajet: + ... + enable_experimental_interchange_server: True + ... + + ``` + + + + +### Explain with examples (Full Workflow Code) + + + +=== "Full Code After Convertion (`as_agentscope_model`)" + + ```python + import re + from loguru import logger + from agentscope.message import Msg + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + + def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if ( + hasattr(result, "metadata") + and isinstance(result.metadata, dict) + and "result" in result.metadata + ): + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + + system_prompt = """ + You are an agent specialized in solving math problems with tools. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + + + class MathToolWorkflow(Workflow): # ✨✨ inherit `Workflow` class + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), # ✨✨ compared with a normal agentscope agent, here is the difference! + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip() + match = re.search(r"\\boxed\{([^}]*)\}", final_answer) + if match: is_success = (match.group(1) == reference_answer) + else: is_success = False + return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) + + ``` + +=== "Full Code After Convertion (`as_agentscope_model`)" + + ```python + import re + from loguru import logger + from agentscope.message import Msg + from agentscope.agent import ReActAgent + from agentscope.formatter import OpenAIChatFormatter + from agentscope.model import OpenAIChatModel + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + + def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if ( + hasattr(result, "metadata") + and isinstance(result.metadata, dict) + and "result" in result.metadata + ): + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + + system_prompt = """ + You are an agent specialized in solving math problems with tools. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + + + class MathToolWorkflow(Workflow): # ✨✨ inherit `Workflow` class + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key # the api key contain information, do not discard it + model = OpenAIChatModel( + model_name="whatever", + client_args={"base_url": base_url}, + api_key=api_key, + stream=False, + ) + self.agent = ReActAgent( + name="math_react_agent", sys_prompt=system_prompt, + model=model, # ✨✨ compared with a normal agentscope agent, here is the difference! + formatter=OpenAIChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip() + match = re.search(r"\\boxed\{([^}]*)\}", final_answer) + if match: is_success = (match.group(1) == reference_answer) + else: is_success = False + return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) + ``` diff --git a/docs/en/support_http.md b/docs/en/support_http.md new file mode 100644 index 00000000..0bf3ab3d --- /dev/null +++ b/docs/en/support_http.md @@ -0,0 +1,95 @@ +# Without Any Agentic Framework + +Why use the Agent SDKs and all these abstractions? If you want to take control of the foundation of LLM Agents, +in this AI era, you can always start from scratch and build your own "high-scrapers". + +## Http + +- use `tuner.as_oai_baseurl_apikey()` to obtain baseurl + apikey arguments + +### Explain with examples + +=== "Before Convertion" + + ```python + # tuner to api key + base_url = "https://openrouter.ai/api/v1" + api_key = "sk-1234567" + + # take out query + query = workflow_task.task.main_query + + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": query + } + ] + + # use raw http requests (non-streaming) to get response + response = requests.post( + f"{base_url}/chat/completions", + json={ + "model": "fill_whatever_model", # Of course, this `model` field will be ignored. + "messages": messages, + }, + headers={ + "Authorization": f"Bearer {api_key}" + } + ) + final_answer = response.json()['choices'][0]['message']['content'] + ``` + +=== "After Convertion (`as_oai_baseurl_apikey`)" + + ```python + # tuner to api key + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + # take out query + query = workflow_task.task.main_query + + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": query + } + ] + + # use raw http requests (non-streaming) to get response + response = requests.post( + f"{base_url}/chat/completions", + json={ + "model": "fill_whatever_model", # Of course, this `model` field will be ignored. + "messages": messages, + }, + headers={ + "Authorization": f"Bearer {api_key}" + } + ) + final_answer = response.json()['choices'][0]['message']['content'] + ``` + + + +!!! warning "" + - when you are using the `tuner.as_oai_baseurl_apikey()` api, you must enable the following feature in the yaml configuration. + + ```yaml + + ajet: + ... + enable_experimental_interchange_server: True + ... + + ``` diff --git a/docs/en/support_langchain.md b/docs/en/support_langchain.md new file mode 100644 index 00000000..d1e12890 --- /dev/null +++ b/docs/en/support_langchain.md @@ -0,0 +1,86 @@ +# Supported Agent Frameworks: AgentScope + +This article introduce the way to convert different types of ways to convert your existing workflows into AgentJet workflows. + + +## AgentScope + +1. use `tuner.as_oai_baseurl_apikey()` to override OpenAIChatModel's baseurl + apikey argument + +### Explain with examples + +=== "Before Convertion" + + ```python + from langchain_openai import ChatOpenAI + + + + + # create openai model + llm = ChatOpenAI( + model="gpt-5", + ) + agent=create_agent( + model=llm, + system_prompt=self.system_prompt, + ) + + # take out query + query = workflow_task.task.main_query + + response = agent.invoke({ + "messages": [ + { + "role": "user", + "content": query + } + ], + }) + ``` + +=== "After Convertion (`as_oai_baseurl_apikey`)" + + ```python + from langchain_openai import ChatOpenAI + + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + llm = ChatOpenAI( + model="whatever", + base_url=base_url, + api_key=lambda:api_key, + ) + agent = create_agent( + model=llm, + system_prompt=self.system_prompt, + ) + + # take out query + query = workflow_task.task.main_query + + response = agent.invoke({ + "messages": [ + { + "role": "user", + "content": query + } + ], + }) + ``` + + + +!!! warning "" + - when you are using the `tuner.as_oai_baseurl_apikey()` api, you must enable the following feature in the yaml configuration. + + ```yaml + + ajet: + ... + enable_experimental_interchange_server: True + ... + + ``` diff --git a/docs/en/support_oaisdk.md b/docs/en/support_oaisdk.md new file mode 100644 index 00000000..b60b03e3 --- /dev/null +++ b/docs/en/support_oaisdk.md @@ -0,0 +1,90 @@ +# Supported Agent Frameworks: AgentScope + +This article introduce the way to convert different types of ways to convert your existing workflows into AgentJet workflows. + + +## AgentScope + +1. use `tuner.as_raw_openai_sdk_client()` to create a openai SDK + +2. use `tuner.as_oai_baseurl_apikey()` to override openai SDK's baseurl + apikey argument + +### Explain with examples + +=== "Before Convertion" + + ```python + import openai + client = openai.OpenAI(api_key='sk-123456') + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": query + } + ] + reply_message: ChatCompletion = await client.chat.completions.create(messages=messages) + final_answer = reply_message.choices[0].message.content + ``` + +=== "After Convertion (`as_raw_openai_sdk_client`)" + + ```python + + client = tuner.as_raw_openai_sdk_client() + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": query + } + ] + reply_message: ChatCompletion = await client.chat.completions.create(messages=messages) + final_answer = reply_message.choices[0].message.content + ``` + + +=== "After Convertion (`as_oai_baseurl_apikey`)" + + ```python + import openai + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + client = openai.OpenAI(api_key=api_key, base_url=base_url) + + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": query + } + ] + reply_message: ChatCompletion = await client.chat.completions.create(messages=messages) + final_answer = reply_message.choices[0].message.content + ``` + + + + + !!! warning "" + - when you are using the `tuner.as_oai_baseurl_apikey()` api, you must enable the following feature in the yaml configuration. + + ```yaml + + ajet: + ... + enable_experimental_interchange_server: True + ... + + ``` diff --git a/docs/en/task_judger.md b/docs/en/task_judger.md new file mode 100644 index 00000000..0b0df367 --- /dev/null +++ b/docs/en/task_judger.md @@ -0,0 +1,228 @@ +# Task Judger + +Task Judger evaluates agent outputs and assigns rewards during training. This page covers built-in judgers for common scenarios and how to create custom judgers for specific evaluation needs. + +!!! warning "When to use the task judger" + - **Is task judger necessary for all tasks? No**: + - There are two options to generate reward: + - Compute reward **inside** the user-defined workflow (`WorkflowOutput.reward is not None`) + - Compute reward **outside** the user-defined workflow (`WorkflowOutput.reward is None`) + - **Task judger** is how AgentJet handles **out-of-workflow** reward computation. + - Task judger will be **Disabled and Ignored** when the user-defined workflow returned an effective `WorkflowOutput.reward` and `WorkflowOutput.reward != None` + - Task judger will be **Enabled** when the user-defined workflow returned `WorkflowOutput.reward = None`. + - **When to use the task judger**: + - When the user plan to **re-used** the reward function in multiple other workflows in the future. + - When the user want to **decouple** rollout and reward computation logic. + - When the user want to use our [**OpenJudge**](https://github.com/modelscope/OpenJudge) integration to generate [Auto Rubrics reward](https://modelscope.github.io/OpenJudge/building_graders/generate_rubrics_as_graders/). + +## Overview + +A Task Judger evaluates the agent's execution results and returns two values: + +| Return Value | Type | Description | +|--------------|------|-------------| +| `raw_reward` | `float` | Numerical score representing output quality (often 0.0 to 1.0) | +| `is_success` | `bool` | Whether the task was successfully completed | + +These values guide the RL training process, helping agents learn which behaviors produce better outcomes. + + +## Base Interface + +All Task Judgers inherit from `BaseJudge` and implement the `compute_reward` method: + +```python title="base_judge.py" +from ajet.task_judge.base_judge import BaseJudge +from ajet.workflow import WorkflowOutput, WorkflowTask + +class BaseJudge: + def __init__(self, config): + self.config = config + + def compute_reward( + self, + workflow_task: WorkflowTask, + workflow_output: WorkflowOutput + ) -> tuple[float, bool]: + """ + Args: + workflow_task: Contains the task data, including metadata with reference answers + workflow_output: Contains the agent's output, including metadata with generated answers + + Returns: + tuple: (raw_reward: float, is_success: bool) + """ + raise NotImplementedError +``` + + +## Built-in Task Judgers + +AgentJet provides three built-in judgers for common evaluation scenarios: + +### 1. MathAnswerAsJudge + +Evaluates mathematical answers by exact string matching, designed for tasks where answers are formatted in LaTeX `\boxed{}` notation. + +!!! tip "When to use" + - Math problem solving tasks + - Tasks with deterministic, exact answers + - Answers formatted as `\boxed{result}` + +=== "Configuration" + + ```yaml title="config.yaml" + ajet: + task_judge: + judge_type: customized_protocol + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + ``` + +=== "How it works" + + 1. Extracts the answer from `\boxed{...}` in the agent's output + 2. Compares with the reference answer from `workflow_task.task.metadata["answer"]` + 3. Returns `(1.0, True)` for correct answers, `(0.0, False)` otherwise + +**Required metadata:** + +| Field | Source | Description | +|-------|--------|-------------| +| `final_answer` | `workflow_output.metadata` | Agent's answer with `\boxed{}` format | +| `answer` | `workflow_task.task.metadata` | Reference answer | + + +### 2. CountdownAnswerAsJudge + +Evaluates mathematical equations with partial credit for proper formatting. + +!!! tip "When to use" + - Number puzzle tasks (e.g., Countdown game) + - Tasks where partial credit is appropriate + - Need to reward proper formatting even when answer is wrong + +=== "Configuration" + + ```yaml title="config.yaml" + ajet: + task_judge: + judge_type: customized_protocol + judge_protocol: tutorial.example_countdown.countdown_answer_as_judge->CountdownAnswerAsJudge + ``` + +=== "Scoring" + + | Score | Condition | + |-------|-----------| + | `0.0` | Invalid or missing answer | + | `0.1` | Properly formatted equation but wrong result | + | `1.0` | Correct equation and result | + + +### 3. EnvServiceJudge + +Delegates evaluation to an external environment service, useful for complex interactive environments. + +!!! tip "When to use" + - Tasks with external simulators (e.g., AppWorld) + - Interactive environments with built-in evaluators + +```yaml title="config.yaml" +ajet: + task_judge: + judge_type: customized_protocol + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge +``` + + +## Creating Custom Task Judgers + +For specialized evaluation needs, create your own judger by inheriting `BaseJudge`: + +
+
Custom Judger Steps
+ +
+
    +
  1. Implement Your Judger + +Create a new file with your custom judger class.
  2. +
  3. Configure Your Judger + +Point to your custom class in the YAML configuration.
  4. +
  5. Pass Data to the Judger + +Populate `workflow_output.metadata` with the data your judger needs.
  6. +
+
+
+ +### Step 1: Implement Your Judger + +```python title="tutorial/my_task/my_judge.py" +from ajet.task_judge.base_judge import BaseJudge +from ajet.workflow import WorkflowOutput, WorkflowTask + +class MyCustomJudge(BaseJudge): + def __init__(self, config): + super().__init__(config) + self.threshold = 0.8 + + def compute_reward( + self, + workflow_task: WorkflowTask, + workflow_output: WorkflowOutput + ) -> tuple[float, bool]: + agent_answer = workflow_output.metadata.get("final_answer", "") + reference_answer = workflow_task.task.metadata.get("answer", "") + + similarity = self._compute_similarity(agent_answer, reference_answer) + is_success = similarity >= self.threshold + return similarity, is_success + + def _compute_similarity(self, text1: str, text2: str) -> float: + return len(set(text1.split()) & set(text2.split())) / max( + len(text1.split()), len(text2.split()), 1 + ) +``` + +### Step 2: Configure Your Judger + +```yaml title="config.yaml" +ajet: + task_judge: + judge_type: customized_protocol + judge_protocol: tutorial.my_task.my_judge->MyCustomJudge +``` + +### Step 3: Pass Data to the Judger + +```python title="workflow.py" +class MyWorkflow(Workflow): + async def execute(self, task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + final_answer = await self.agent.reply(msg) + return WorkflowOutput( + reward=None, # Will be filled by the judger + metadata={ + "final_answer": final_answer, + } + ) +``` + + +## Configuration Summary + +```yaml title="config.yaml" +ajet: + task_judge: + judge_type: customized_protocol + judge_protocol: ajet.task_judge.-> +``` + + +## Next Steps + + diff --git a/docs/en/tune_your_first_agent.md b/docs/en/tune_your_first_agent.md new file mode 100644 index 00000000..2ecd9f1c --- /dev/null +++ b/docs/en/tune_your_first_agent.md @@ -0,0 +1,417 @@ +# Tune Your First Agent + +In this document, we demonstrate how to implement and train, from scratch, an agent that can use Python to perform calculations and solve 'gsm8k' math problems. + + + +
+
Training Pipeline Overview
+ +
+
    + +
  1. Define agent workflow + +Create your agent using AgentScope/Langchain/OpenaiSDK or only http requests, wrap it in a Workflow class.
  2. + +
  3. Define reward + +Configure how the agent's outputs are evaluated and scored.
  4. + +
  5. Prepare dataset + +Set up the dataset and configure the task reader.
  6. + + +
  7. Debug (Optional) + +Test your workflow in debug mode before full training.
  8. +
  9. Start training + +Launch the training process and track progress.
  10. +
+
+
+ + + +!!! info "" + Checkout the full code of this example by [clicking here](#full-code) + + + +## Step 1: ✨Define agent Workflow + Reward + + +First of all, create a directory for this training project: + +```bash +tutorial/example_math_agent +├── math_agent.py +└── math_agent.yaml +``` + +Next, define your workflow (or convert an existing workflow). Here we use AgentScope to implement this agent. You can toggle two code before and after convertion to see the difference. If you prefer langchain or openai sdk, [please refer to this article](agent_framework_support.md). + +=== "`math_agent.py` - AgentJet Workflow (After Convertion)" + + ```python title="math_agent.py" + class MathToolWorkflow(Workflow): # ✨✨ inherit `Workflow` class + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), # ✨✨ compared with a normal agentscope agent, here is the difference! + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip() + match = re.search(r"\\boxed\{([^}]*)\}", final_answer) + if match: is_success = (match.group(1) == reference_answer) + else: is_success = False + return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) + + ``` + + +=== "Original Workflow (Before Convertion)" + + ```python title="math_agent.py" + class MathToolWorkflow(object): + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask) -> WorkflowOutput: + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", sys_prompt=system_prompt, + model=DashScopeChatModel(model='qwen-max'), + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip() + match = re.search(r"\\boxed\{([^}]*)\}", final_answer) + if match: is_success = (match.group(1) == reference_answer) + else: is_success = False + return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) + + ``` + + + +## Step 2: ✨Prepare dataset + +!!! info "Data Sources" + AgentJet provides multiple ways to read data: + + - Read from local files on disk + - Read from a Hugging Face repo + - Read from an EnvService + + +Download the `openai/gsm8k` dataset: + +```bash +python scripts/download_dataset.py --target=openai/gsm8k --path=/the/path/to/store/dataset +``` + +Now, we have obtained all materials required to train the agent. + + +=== "`math_agent.yaml` - Configuration Yaml" + + ```yaml + # ------------------ main configuration ------------------ + ajet: + project_name: example_math_agent + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + # effective when `type: huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: 'openai/gsm8k' + training_split: "train" + validation_split: "test" + + task_judge: + # ✨✨✨✨ null, because in this certain case, we write reward function together with workflow + judge_protocol: null + + model: + # ✨✨✨✨ set the model to be trained + path: Qwen/Qwen2.5-7B + + rollout: + user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ write and select workflow + num_repeat: 6 # grpo `n` + tensor_model_parallel_size: 1 # vllm tp + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 100 + total_epochs: 100 + logger: swanlab + + # ------------------ do not modify ------------------ + hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + - file://ajet/default_config/trinity + + # ------------------ do not modify ------------------ + defaults: + - verl_default + - trinity_default + - ajet_default + - _self_ + + ``` + +### Configuration Parameters + +| Category | Parameter | Description | Example Value | +|----------|-----------|-------------|---------------| +| **Project** | `project_name` | Name of the training project | `example_math_agent` | +| **Task Reader** | `type` | Type of data source to read tasks from | `huggingface_dat_repo` (options: `env_service`, `dataset_file`, `huggingface_dat_repo`) | +| | `dataset_path` | Path or identifier of the dataset | `openai/gsm8k` | +| | `training_split` | Dataset split used for training | `train` | +| | `validation_split` | Dataset split used for validation/testing | `test` | +| **Model** | `path` | Path or identifier of the model to be trained | `Qwen/Qwen2.5-7B` | +| **Rollout** | `user_workflow` | Python module path to the workflow class | `tutorial.example_math_agent.math_agent->ExampleMathLearn` | +| | `num_repeat` | Number of rollout repeats per task (GRPO `n` parameter) | `6` | +| | `tensor_model_parallel_size` | vLLM tensor parallelism size | `1` | +| | `max_response_length_in_one_turn` | Maximum token length for a single agent response | `1024` | +| | `max_model_len` | Maximum total context length for the model | `10000` | +| **Data** | `train_batch_size` | Number of tasks per training batch | `100` | +| | `max_prompt_length` | Maximum token length for input prompts | `3000` | +| | `max_response_length` | Maximum token length for model responses | `7000` | +| **Debug** | `debug_max_parallel` | Maximum parallel workers in debug mode | `1` | +| | `debug_first_n_tasks` | Number of tasks to process in debug mode | `1` | +| **Trainer** | `save_freq` | Frequency (in steps) to save model checkpoints | `100` | +| | `test_freq` | Frequency (in steps) to run validation | `100` | +| | `total_epochs` | Total number of training epochs | `100` | +| | `logger` | Logging backend for experiment tracking | `swanlab` | +| **Task Judge** | `judge_protocol` | Protocol for judging task completion | `null` (reward is computed in workflow) | + + +## Step 3: ✨Debug (Optional) + +Before full training, you can run some test in debug mode, using raw base model to test whether bug exists. +We choose VSCode to debug because it is open-source and fast. + + +!!! tip "VS Code Debugging" + - You can create `.vscode/launch.json` for breakpoint debugging: + + ```json + { + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "module": "ajet.launcher", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--conf", "tutorial/example_math_agent/math_agent.yaml" + ], + "env": {} + } + ] + } + ``` + +After `.vscode/launch.json` is created, press `F5` to start debugging. (Do not forget to configure python venv path in VSCode.) + +For more debugging techniques, please refer to [debugging guidelines](debugging_guide.md). + + +## Step 4: ✨Start Training + +After debugging, launch the full training: + +```bash +ajet --conf tutorial/example_math_agent/math_agent.yaml +``` + +!!! success "Output Location" + Training logs and checkpoints will be saved default to: + ``` + ./saved_experiments/{exp_yaml_file_name}/ + ``` + + + +## Full Code {#full-code} + +=== "`tutorial/example_math_agent/math_agent.py` - AgentJet Workflow (After Convertion)" + + ```python + import re + from loguru import logger + from agentscope.message import Msg + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + + def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if ( + hasattr(result, "metadata") + and isinstance(result.metadata, dict) + and "result" in result.metadata + ): + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + + system_prompt = """ + You are an agent specialized in solving math problems with tools. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + + + class MathToolWorkflow(Workflow): # ✨✨ inherit `Workflow` class + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), # ✨✨ compared with a normal agentscope agent, here is the difference! + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip() + match = re.search(r"\\boxed\{([^}]*)\}", final_answer) + if match: is_success = (match.group(1) == reference_answer) + else: is_success = False + return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) + + ``` + +=== "`tutorial/example_math_agent/math_agent.yaml` - Configuration Yaml" + + ```yaml + # ------------------ main configuration ------------------ + ajet: + project_name: example_math_agent + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + # effective when `type: huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: 'openai/gsm8k' # '/mnt/data_cpfs/dataset_cache/openai/gsm8k/main' + training_split: "train" + validation_split: "test" + + model: + # ✨✨✨✨ set the model to be trained + path: Qwen/Qwen2___5-7B-Instruct # /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + rollout: + user_workflow: "tutorial/example_math_agent/math_agent.py->MathToolWorkflow" # ✨✨✨✨ write and select workflow + num_repeat: 6 # grpo `n` + tensor_model_parallel_size: 1 # vllm tp + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + + task_judge: + # ✨✨✨✨ null, because in this certain case, we write reward function together with workflow + judge_protocol: null + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 100 + total_epochs: 100 + logger: swanlab + + # ------------------ do not modify ------------------ + hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + - file://ajet/default_config/trinity + + # ------------------ do not modify ------------------ + defaults: + - verl_default + - trinity_default + - ajet_default + - _self_ + + ``` + +## Next Steps + + diff --git a/docs/en/visualization.md b/docs/en/visualization.md new file mode 100644 index 00000000..bc66d52d --- /dev/null +++ b/docs/en/visualization.md @@ -0,0 +1,117 @@ +# Training Visualization + +Monitoring training progress through visualized metrics is essential for understanding model behavior and tuning hyperparameters effectively. + +--- + +## Supported Visualization Tools + +
+

SwanLab ⭐

Modern experiment tracking platform designed for AI research. Recommended.

+

WandB

Weights & Biases experiment tracking platform.

+

Console

Simple text-based logging to standard output.

+
+ +--- + +## Quick Start with SwanLab + +### Step 1: Configure SwanLab + +Simply set the logger backend to `swanlab` in your YAML configuration: + +```yaml title="config.yaml" +ajet: + trainer_common: + logger: swanlab +``` + +### Step 2: Start Training + +Launch your training as usual: + +```bash +ajet --conf tutorial/example_math_agent/math_agent.yaml +``` + +### Step 3: View Training Curves + +!!! success "Automatic Tracking" + Once training starts, SwanLab will automatically: + + 1. Track key metrics (reward, success rate, loss, etc.) + 2. Generate real-time training curves + 3. Provide a web dashboard for visualization + +You can access the SwanLab dashboard through the URL printed in the training logs. + +--- + +## Understanding Training Curves + +### Key Metrics to Monitor + +| Metric | Description | +|--------|-------------| +| **Reward** | Average reward per episode, indicating task performance | +| **Success Rate** | Percentage of successfully completed tasks | +| **Loss** | Training loss from the policy optimization algorithm | +| **Response Length** | Average length of model responses | +| **KL Divergence** | Divergence between current and reference policy | + +### Interpreting the Curves + +**Example Training Curve:** + +![Example Training Curve](https://img.alicdn.com/imgextra/i4/O1CN01gzwgLq1fkCnauydEu_!!6000000004044-2-tps-1422-550.png) + +A typical reward curve shows: + +| Phase | Description | +|-------|-------------| +| **Initial** | Reward may be low or unstable as the model explores | +| **Learning** | Reward gradually increases as the model learns better strategies | +| **Convergence** | Reward plateaus when the model reaches optimal performance | + +!!! tip "What to Look For" + - **Rising trend**: Indicates successful learning + - **Plateaus**: May indicate convergence or need for hyperparameter adjustment + - **Sudden drops**: Could signal instability or overfitting + +--- + +## Best Practices + +### Monitor Multiple Runs + +Compare different hyperparameter settings by running multiple experiments and comparing their curves side-by-side. + +### Set Appropriate Logging Frequency + +Balance between logging detail and training overhead: + +```yaml title="config.yaml" +ajet: + trainer_common: + log_freq: 1 # Log every N steps +``` + +### Save Checkpoints at Key Points + +Configure checkpoint saving to preserve models at peak performance: + +```yaml title="config.yaml" +ajet: + trainer_common: + save_freq: 100 # Save every 100 steps +``` + +--- + +## Next Steps + + diff --git a/docs/en/workflow.md b/docs/en/workflow.md new file mode 100644 index 00000000..1137f02c --- /dev/null +++ b/docs/en/workflow.md @@ -0,0 +1,256 @@ +# Trainable Workflow + +This tutorial introduces how to define a trainable workflow. + +!!! info "" + AgentJet provides two **convenient** and **mutually compatible** ways to wrap your Workflow: + + - **Simple**: Emphasizes simplicity, ease of use, and readability + - **Advanced**: Emphasizes flexibility, controllability, and extensibility + +In this article we use **AgentScope** framework for demonstration. For other frameworks (OpenAI SDK, Langchain, HTTP Requests), please follow the same pattern. + + + +
+AgentJet +
+ +## Simple Practice + +!!! Example "Simple Practice Abstract" + - Simply set `model` argument in AgentScope ReActAgent argument to `tuner.as_agentscope_model()` when initializing your agent. + - Wrap your code with `class MyWorkflow(Workflow)` and your agent is ready to be tuned. + +### 1. When to Use This Simple Practice + +!!! warning "Choose Simple Practice If You..." + - Know exactly which agents should be trained, or the number of agents is small + - Already finished basic debugging of your workflow + - Do not need to change which agents are trained on the fly + + +### 2. Convert Your Workflow to AgentJet Trainable Workflow + +The very first step is to create a class as a container to wrap your code: + +=== "`converted_workflow.py` - AgentJet Workflow" + + ```python + from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + class MyWorkflow(Workflow): + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # ... your ReActAgent workflow here ✈️ ... + return WorkflowOutput(reward=..., metadata={...}) + + ``` + + +Next, use the `tuner` argument, call its `tuner.as_agentscope_model()` method: + +=== "Before" + + ```python + model = DashScopeChatModel(model_name="qwen-max", stream=False) # ✈️ change here + agent_instance = ReActAgent( + name=f"Friday", + sys_prompt="You are a helpful assistant", + model=model, + formatter=DashScopeChatFormatter(), + ) + ``` + +=== "After" + + ```python + model = tuner.as_agentscope_model() # ✈️ change here + agent_instance = ReActAgent( + name=f"Friday", + sys_prompt="You are a helpful assistant", + model=model, + formatter=DashScopeChatFormatter(), + ) + ``` + +!!! warning "AjetTuner" + `AjetTuner` also has `.as_raw_openai_sdk_client()` and `.as_oai_baseurl_apikey()` method. But `.as_agentscope_model()` is more convenient for AgentScope agent workflow. + + + +### 3. Code Example + + + + + + + +## Advanced Practice + +!!! Example "Advanced Practice Abstract" + - The `tuner.as_agentscope_model()` function has hidden parameters, please further complete them to tell AgentJet the identity of agents. + - The `ajet.Workflow` class has hidden attribute `trainable_targets`, please assign it manually to narrow down agents to be tuned. + +### 1. When to Use Advanced Practice + +When designing a **multi-agent collaborative** workflow where each agent plays a different **target_tag**, AgentJet provides enhanced training and debugging capabilities. + +!!! warning "Multi-Agent Benefits" + With a multi-agent setup, you can: + + - **Precisely control** which agents are fine-tuned + - Explicitly define the default model for agents **not being trained** + - Switch trainable targets on the fly **without modifying** source code + +### 1. How to promote to advanced agent scenario: + +Simple, there are only two more issues that should be take care of in addition: + +i. **`.as_agentscope_model` has three hidden (optional) parameters, complete them for each agent.** + +| parameter | explanation | +|----------|------------| +| `agent_name` | The name of this agent | +| `target_tag` | A tag that mark the agent category | +| `debug_model` | The model used when this agent is not being tuned | + +=== "`as_agentscope_model()` parameters" + + ```python + model_for_an_agent = tuner.as_agentscope_model( + agent_name="AgentFriday", # the name of this agent + target_tag="Agent_Type_1", # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. + debug_model=OpenAIChatModel( + model_name="Qwen/Qwen3-235B-A22B-Instruct-2507", + stream=False, + api_key="api_key", + ), # the model used when this agent is not in `self.trainable_targets` + ) + ``` + +ii. **`Workflow` has a hidden (optional) attribute called `trainable_targets`, config it.** + +| `trainable_targets` value | explanation | +|----------|------------| +| `trainable_targets = None` | All agents using `as_agentscope_model` will be trained | +| `trainable_targets = ["Agent_Type_1", "Agent_Type_2"]` | Agents with `target_tag=Agent_Type_1`, `target_tag=Agent_Type_2`, ... will be trained | +| `trainable_targets = []` | Illegal, no agents are trained | + + +| Scenario | Model Used | +|----------|------------| +| `target_tag` in `trainable_targets` | Trainable model | +| `target_tag` NOT in `trainable_targets` | Registered `debug_model` | + + + +!!! warning + Regardless of `target_tag` differences, all agents share a single model instance (one model weight to play different roles, the model receives different perceptions when playing different roles). + + +### 2. Multi-Agent Example + +Here's a complete example with multiple agent roles (Werewolves game): + +=== "`tutorial/example_werewolves/start.py`" + ```python + class ExampleWerewolves(Workflow): + trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + + # ensure trainable targets is legal + assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)." + + # bad guys and good guys cannot be trained simultaneously + # (because mix-cooperation-competition MARL needs too many advanced techniques to be displayed here) + if "werewolf" in self.trainable_targets: + assert len(self.trainable_targets) == 1, "Cannot train hostile roles simultaneously." + else: + assert len(self.trainable_targets) != 0, "No trainable targets specified." + + # make and shuffle roles (fix random seed for reproducibility) + roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] + task_id = workflow_task.task.metadata["random_number"] + np.random.seed(int(task_id)) + np.random.shuffle(roles) + + # initialize agents + players = [] + for i, role in enumerate(roles): + default_model = OpenAIChatModel( + model_name="Qwen/Qwen3-235B-A22B-Instruct-2507", + stream=False, + api_key="no_api_key", + ) + model_for_this_agent = tuner.as_agentscope_model( + agent_name=f"Player{i + 1}", # the name of this agent + target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. + debug_model=default_model, # the model used when this agent is not in `self.trainable_targets` + ) + agent = ReActAgent( + name=f"Player{i + 1}", + sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), + model=model_for_this_agent, + formatter=DashScopeMultiAgentFormatter() + if role in self.trainable_targets + else OpenAIMultiAgentFormatter(), + max_iters=3 if role in self.trainable_targets else 5, + ) + # agent.set_console_output_enabled(False) + players += [agent] + + # reward condition + try: + good_guy_win = await werewolves_game(players, roles) + raw_reward = 0 + is_success = False + if (good_guy_win and self.trainable_targets[0] != "werewolf") or ( + not good_guy_win and self.trainable_targets[0] == "werewolf" + ): + raw_reward = 1 + is_success = True + logger.warning(f"Raw reward: {raw_reward}") + logger.warning(f"Is success: {is_success}") + except BadGuyException as e: + logger.bind(exception=True).exception( + f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent)." + ) + raw_reward = -0.1 + is_success = False + except Exception as e: + logger.bind(exception=True).exception( + f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent)." + ) + raw_reward = -0.1 + is_success = False + + return WorkflowOutput(reward=raw_reward, is_success=is_success) + ``` + +!!! tip "Configuration Flexibility" + In this example: + + - `role` describes an agent's in-game identity (werewolf, villager, etc.) + - `chosen_model` defines the default model when the role is not being trained + - You can flexibly switch training targets by modifying `trainable_targets` + + +## TinkerJet + +Wrapping and training your agent on a machine without GPU. + +Working in progress and coming soon. + + +## Next Steps + + diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..60963908 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,171 @@ +# AgentJet + + +
+AgentJet +
+ +**AgentJet (AJet)** is a cutting-edge, user-friendly agent tuning framework designed to optimize LLM models and agent workflows. + + +## ✈️ Key Features + +
+ +
+

Get Started with Ease

+
+

+ AgentJet simplifies the process of tuning the models that power your agent workflows. It supports nearly all major agent frameworks (e.g. agentscope, langchain), as well as framwork-less agents built from HTTP requests. +

+
+ +
+

Rich Tutorial Library

+
+

+ Rich examples as beginner's tutorial: math agent, werewolves rpg, appworld ... All with step-by-step + guides. Covering various agentic frameworks.

+
+ +
+

Reliable and Reproducible

+
+

+ Checkout AgentJet's community-powered, robot-assisted open-benchmarking system. + Share progress, compare training backbones, discover bugs and iterate faster than ever! + Click here to see AgentJet performance across tasks/versions/backbones. +

+
+ +
+

Multi-agent and Multi-turn

+
+

+ Built to support advanced multi-agent and multi-turn LLM workflows, + AgentJet intergrates timeline-merging algorithms that + automatically analyze and consolidate each agent's LLM timeline, + accelerating training speed 1.5x ~ 10x. +

+
+ +
+

High Resolution Logging

+
+

+ Log token-level rollout details, capturing token IDs, token loss masks, and token log probabilities with web UI display. This Support workflow development, agent diagnostics, and facilitate research on advanced LLM algorithm studies. +

+
+ +
+

Any Training Engine

+
+

+ Support multiple training engines as backbone (VeRL and Trinity-RFT). Tinker backbone support will be released soon. + Choose from vLLM and SGLang as you wish. Say goodbye to training engine gaps. +

+
+
+ + + +## ✈️ Quick Start + + + +We recommend using `uv` for dependency management. [Click here](en/installation.md) for details and other training backbone (e.g. Trinity-RFT) options. + +- Clone the Repository: + ```bash + git clone https://github.com/modelscope/AgentJet.git + cd AgentJet + ``` + +- Set up Environment: + ```bash + uv venv --python=3.10.16 && source .venv/bin/activate + uv pip install -e .[verl] + + # Note: flash-attn must be installed after other dependencies + uv pip install flash_attn==2.8.3 --no-build-isolation --no-cache-dir + ``` + +- Train the First Agent: + ```bash + # You can start training your first agent with a single command using a pre-configured YAML file + + ajet --conf tutorial/example_math_agent/math_agent.yaml + ``` + + + + +## ✈️ Example Library {#example-library} + +Explore our rich library of examples to kickstart your journey: + + + + +## ✈️ Core Concepts + +AgentJet makes agent fine-tuning straightforward by separating the developer interface from the internal execution logic. + +**✈️ The User-Centric Interface** + +To optimize an agent, you provide three core inputs: + + + +
+AgentJet Architecture +
+ +**✈️ Internal System Architecture** + +The internal system orchestrates several specialized modules to handle the complexities of RL training and agent interactions. + +| Module | Description | +|--------|-------------| +| **Launcher** | Manages background service processes (Ray, vLLM) and routes the backbone | +| **Task Rollout** | Bridges LLM engines and manages the Gym environment lifecycle | +| **Task Runner** | Executes the AgentScope workflow and calculates rewards | +| **Model Tuner** | Forwards inference requests from the workflow to the LLM engine | +| **Context Tracker** | Monitors LLM calls and automatically merges shared-history timelines (1.5x-10x efficiency boost) | + + + + +## ✈️ Next Steps + + + + + diff --git a/docs/javascripts/animations.js b/docs/javascripts/animations.js new file mode 100644 index 00000000..a5dc584a --- /dev/null +++ b/docs/javascripts/animations.js @@ -0,0 +1,401 @@ +/** + * Animations & Visual Enhancements JavaScript + * Phase 3: 视觉增强 + * + * Features: + * - Scroll-triggered animations + * - Image lazy loading complete handler + * - Copy button animations + * - Smooth scroll behaviors + */ + +(function() { + 'use strict'; + + // ======================================== + // Configuration + // ======================================== + + const config = { + scrollThreshold: 0.1, // 10% of element visible triggers animation + observerOptions: { + root: null, + rootMargin: '0px', + threshold: 0.1 + } + }; + + // ======================================== + // Scroll Animations + // ======================================== + + /** + * Initialize Intersection Observer for scroll animations + */ + function initScrollAnimations() { + // Check if browser supports IntersectionObserver + if (!('IntersectionObserver' in window)) { + console.log('IntersectionObserver not supported, skipping scroll animations'); + return; + } + + // Select elements to animate on scroll + const animateElements = document.querySelectorAll('.fade-in-on-scroll, .slide-in-left, .slide-in-right'); + + if (animateElements.length === 0) return; + + const observer = new IntersectionObserver((entries) => { + entries.forEach(entry => { + if (entry.isIntersecting) { + entry.target.classList.add('visible'); + // Optionally unobserve after animation + // observer.unobserve(entry.target); + } + }); + }, config.observerOptions); + + animateElements.forEach(el => observer.observe(el)); + } + + // ======================================== + // Image Loading + // ======================================== + + /** + * Handle image lazy loading completion + */ + function initImageAnimations() { + const lazyImages = document.querySelectorAll('img[loading="lazy"]'); + + lazyImages.forEach(img => { + // If image is already loaded + if (img.complete) { + img.classList.add('loaded'); + } else { + // Wait for image to load + img.addEventListener('load', function() { + this.classList.add('loaded'); + }); + + // Handle load errors + img.addEventListener('error', function() { + console.warn('Failed to load image:', this.src); + this.classList.add('loaded'); // Remove shimmer even on error + }); + } + }); + } + + // ======================================== + // Code Block Enhancements + // ======================================== + + /** + * Add language badges to code blocks + */ + function addCodeLanguageBadges() { + const codeBlocks = document.querySelectorAll('pre code[class*="language-"]'); + + codeBlocks.forEach(code => { + const parentPre = code.closest('pre'); + if (!parentPre || parentPre.querySelector('.language-name')) return; + + // Extract language from class + const languageClass = Array.from(code.classList).find(cls => cls.startsWith('language-')); + if (!languageClass) return; + + const language = languageClass.replace('language-', ''); + + // Create badge + const badge = document.createElement('span'); + badge.className = 'language-name'; + badge.textContent = language; + + // Add to parent pre + parentPre.style.position = 'relative'; + parentPre.appendChild(badge); + }); + } + + /** + * Enhanced copy button behavior + */ + function initCopyButtonAnimations() { + // Listen for copy events on the document + document.addEventListener('click', function(e) { + const copyButton = e.target.closest('.copy-button, .md-clipboard, [data-clipboard-target]'); + if (!copyButton) return; + + // Add copied class for animation + copyButton.classList.add('copied'); + + // Optional: Change button text temporarily + const originalText = copyButton.textContent; + if (originalText && !copyButton.querySelector('svg')) { + copyButton.textContent = '✓ Copied!'; + } + + // Remove after animation + setTimeout(() => { + copyButton.classList.remove('copied'); + if (originalText && !copyButton.querySelector('svg')) { + copyButton.textContent = originalText; + } + }, 2000); + }); + } + + // ======================================== + // Smooth Scroll + // ======================================== + + /** + * Smooth scroll to anchor links + */ + function initSmoothScroll() { + document.addEventListener('click', function(e) { + const link = e.target.closest('a[href^="#"]'); + if (!link) return; + + const targetId = link.getAttribute('href').slice(1); + if (!targetId) return; + + const targetElement = document.getElementById(targetId); + if (!targetElement) return; + + e.preventDefault(); + + targetElement.scrollIntoView({ + behavior: 'smooth', + block: 'start' + }); + + // Update URL without jumping + if (history.pushState) { + history.pushState(null, null, `#${targetId}`); + } + }); + } + + // ======================================== + // Reduced Motion Preference + // ======================================== + + /** + * Respect user's reduced motion preference + */ + function handleReducedMotion() { + const prefersReducedMotion = window.matchMedia('(prefers-reduced-motion: reduce)'); + + function applyReducedMotion(e) { + if (e.matches) { + document.documentElement.style.setProperty('--rm-transition-fast', '0.01ms'); + document.documentElement.style.setProperty('--rm-transition-normal', '0.01ms'); + document.documentElement.style.setProperty('--rm-transition-slow', '0.01ms'); + } else { + document.documentElement.style.setProperty('--rm-transition-fast', '0.15s'); + document.documentElement.style.setProperty('--rm-transition-normal', '0.25s'); + document.documentElement.style.setProperty('--rm-transition-slow', '0.4s'); + } + } + + // Initial check + applyReducedMotion(prefersReducedMotion); + + // Listen for changes + prefersReducedMotion.addEventListener('change', applyReducedMotion); + } + + // ======================================== + // Tab Switching Enhancements + // ======================================== + + /** + * Add smooth transitions to tab content + */ + function enhanceTabSwitching() { + const tabInputs = document.querySelectorAll('.tabbed-set input[type="radio"]'); + + tabInputs.forEach(input => { + input.addEventListener('change', function() { + const tabbedSet = this.closest('.tabbed-set'); + if (!tabbedSet) return; + + const activeBlock = tabbedSet.querySelector('.tabbed-block--active'); + if (activeBlock) { + // Add fade-out animation to old content + activeBlock.style.animation = 'fadeOut 0.15s ease-out'; + + setTimeout(() => { + activeBlock.style.animation = ''; + }, 150); + } + }); + }); + } + + // ======================================== + // Collapsible Details Enhancement + // ======================================== + + /** + * Enhance details/summary elements + */ + function enhanceDetails() { + const detailsElements = document.querySelectorAll('details'); + + detailsElements.forEach(details => { + details.addEventListener('toggle', function() { + if (this.open) { + // Add expand animation + const content = Array.from(this.children).find(el => el.tagName !== 'SUMMARY'); + if (content) { + content.style.animation = 'slideDown 0.25s ease-out'; + } + } + }); + }); + } + + // ======================================== + // Navigation Enhancements + // ======================================== + + /** + * Add active indicator animations to navigation + */ + function enhanceNavigation() { + // Highlight current page in navigation + const currentPath = window.location.pathname; + const navLinks = document.querySelectorAll('.md-nav__link, nav a'); + + navLinks.forEach(link => { + const linkPath = new URL(link.href, window.location.origin).pathname; + + if (linkPath === currentPath) { + link.classList.add('active'); + link.setAttribute('aria-current', 'page'); + + // Ensure parent items are expanded + let parent = link.closest('.md-nav__item--nested, li.has-children'); + while (parent) { + const toggle = parent.querySelector('input[type="checkbox"], .md-nav__toggle'); + if (toggle) { + toggle.checked = true; + } + parent = parent.parentElement.closest('.md-nav__item--nested, li.has-children'); + } + } + }); + } + + // ======================================== + // Performance: Debounce utility + // ======================================== + + function debounce(func, wait) { + let timeout; + return function executedFunction(...args) { + const later = () => { + clearTimeout(timeout); + func(...args); + }; + clearTimeout(timeout); + timeout = setTimeout(later, wait); + }; + } + + // ======================================== + // Scroll Progress Indicator (Optional) + // ======================================== + + /** + * Add reading progress bar to top of page + */ + function initScrollProgress() { + // Check if progress bar element exists + let progressBar = document.querySelector('.scroll-progress'); + + if (!progressBar) { + // Create progress bar + progressBar = document.createElement('div'); + progressBar.className = 'scroll-progress'; + progressBar.style.cssText = ` + position: fixed; + top: 0; + left: 0; + width: 0%; + height: 3px; + background: var(--primary, #3b82f6); + z-index: 9999; + transition: width 0.1s ease-out; + `; + document.body.appendChild(progressBar); + } + + const updateProgress = debounce(() => { + const windowHeight = window.innerHeight; + const documentHeight = document.documentElement.scrollHeight - windowHeight; + const scrolled = window.scrollY; + const progress = (scrolled / documentHeight) * 100; + + progressBar.style.width = `${Math.min(progress, 100)}%`; + }, 10); + + window.addEventListener('scroll', updateProgress); + updateProgress(); // Initial call + } + + // ======================================== + // Initialization + // ======================================== + + /** + * Initialize all animations when DOM is ready + */ + function init() { + console.log('🎨 Initializing OpenJudge animations...'); + + // Core animations + handleReducedMotion(); + initScrollAnimations(); + initImageAnimations(); + initSmoothScroll(); + + // UI enhancements + addCodeLanguageBadges(); + initCopyButtonAnimations(); + enhanceTabSwitching(); + enhanceDetails(); + enhanceNavigation(); + + // Optional: Enable scroll progress + // initScrollProgress(); + + console.log('✨ Animations initialized successfully'); + } + + // Run on DOM ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', init); + } else { + // DOM is already ready + init(); + } + + // Re-initialize on page navigation (for SPA-like behavior) + if ('navigation' in window && 'addEventListener' in window.navigation) { + window.navigation.addEventListener('navigate', () => { + setTimeout(init, 100); + }); + } + + // Expose utilities to global scope (optional) + window.OpenJudgeAnimations = { + debounce, + initScrollAnimations, + initImageAnimations, + addCodeLanguageBadges + }; + +})(); diff --git a/docs/javascripts/code-copy.js b/docs/javascripts/code-copy.js new file mode 100644 index 00000000..9f66108f --- /dev/null +++ b/docs/javascripts/code-copy.js @@ -0,0 +1,174 @@ +/** + * Code Copy Button - Universal code block copy functionality + * Adds copy buttons to all code blocks (non-tabbed) + */ + +(function() { + 'use strict'; + + function initCodeCopyButtons() { + // Find all code blocks that don't already have a copy button + // Exclude tabbed code blocks (handled by tabbed-code.js) + const codeBlocks = document.querySelectorAll('article pre, .prose pre, .md-typeset pre'); + + codeBlocks.forEach(function(preElement) { + // Skip if already has a copy button + if (preElement.querySelector('.copy-button')) { + return; + } + + // Skip if it's inside a tabbed set + if (preElement.closest('.tabbed-set')) { + return; + } + + // Skip if it's a tabbed block + if (preElement.classList.contains('tabbed-block') || preElement.closest('.tabbed-block')) { + return; + } + + // Create copy button + const copyButton = document.createElement('button'); + copyButton.className = 'copy-button code-copy-btn'; + copyButton.innerHTML = ''; + copyButton.setAttribute('aria-label', 'Copy code'); + copyButton.setAttribute('title', 'Copy code'); + + // Add click handler + copyButton.addEventListener('click', function(e) { + e.preventDefault(); + e.stopPropagation(); + + // Get code content + const codeElement = preElement.querySelector('code'); + if (!codeElement) return; + + const code = codeElement.textContent; + + // Use modern clipboard API + if (navigator.clipboard && window.isSecureContext) { + navigator.clipboard.writeText(code).then(function() { + showCopiedState(copyButton); + }).catch(function(err) { + console.error('Failed to copy:', err); + fallbackCopyTextToClipboard(code, copyButton); + }); + } else { + // Fallback for older browsers or non-secure contexts + fallbackCopyTextToClipboard(code, copyButton); + } + }); + + // Insert copy button into pre element + preElement.style.position = 'relative'; + preElement.appendChild(copyButton); + }); + + // Also handle .highlight wrapper (Pygments) + const highlightBlocks = document.querySelectorAll('article .highlight, .prose .highlight, .md-typeset .highlight'); + + highlightBlocks.forEach(function(highlightElement) { + // Skip if already has a copy button + if (highlightElement.querySelector('.copy-button')) { + return; + } + + // Skip if it's inside a tabbed set + if (highlightElement.closest('.tabbed-set')) { + return; + } + + // Skip if it's a tabbed block + if (highlightElement.classList.contains('tabbed-block') || highlightElement.closest('.tabbed-block')) { + return; + } + + // Create copy button + const copyButton = document.createElement('button'); + copyButton.className = 'copy-button code-copy-btn'; + copyButton.innerHTML = ''; + copyButton.setAttribute('aria-label', 'Copy code'); + copyButton.setAttribute('title', 'Copy code'); + + // Add click handler + copyButton.addEventListener('click', function(e) { + e.preventDefault(); + e.stopPropagation(); + + // Get code content + const codeElement = highlightElement.querySelector('pre code') || highlightElement.querySelector('code'); + if (!codeElement) return; + + const code = codeElement.textContent; + + // Use modern clipboard API + if (navigator.clipboard && window.isSecureContext) { + navigator.clipboard.writeText(code).then(function() { + showCopiedState(copyButton); + }).catch(function(err) { + console.error('Failed to copy:', err); + fallbackCopyTextToClipboard(code, copyButton); + }); + } else { + // Fallback for older browsers or non-secure contexts + fallbackCopyTextToClipboard(code, copyButton); + } + }); + + // Insert copy button into highlight element + highlightElement.style.position = 'relative'; + highlightElement.appendChild(copyButton); + }); + } + + function fallbackCopyTextToClipboard(text, button) { + const textArea = document.createElement('textarea'); + textArea.value = text; + textArea.style.position = 'fixed'; + textArea.style.left = '-999999px'; + textArea.style.top = '-999999px'; + document.body.appendChild(textArea); + textArea.focus(); + textArea.select(); + + try { + const successful = document.execCommand('copy'); + if (successful) { + showCopiedState(button); + } + } catch (err) { + console.error('Fallback: Failed to copy', err); + } + + document.body.removeChild(textArea); + } + + function showCopiedState(button) { + const originalHTML = button.innerHTML; + button.innerHTML = ''; + button.classList.add('copied'); + + setTimeout(function() { + button.innerHTML = originalHTML; + button.classList.remove('copied'); + }, 2000); + } + + // Run on DOM ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', initCodeCopyButtons); + } else { + initCodeCopyButtons(); + } + + // Also re-initialize on navigation (for SPA-like behavior in MkDocs Material) + if (typeof document$ !== 'undefined') { + document$.subscribe(function() { + // Use setTimeout to ensure DOM is fully updated + setTimeout(initCodeCopyButtons, 100); + }); + } + + // Export for manual re-initialization if needed + window.initCodeCopyButtons = initCodeCopyButtons; +})(); diff --git a/docs/javascripts/code-zoom.js b/docs/javascripts/code-zoom.js new file mode 100644 index 00000000..22d3d624 --- /dev/null +++ b/docs/javascripts/code-zoom.js @@ -0,0 +1 @@ +/* Code zoom - placeholder */ diff --git a/docs/javascripts/nav-scroll-fix.js b/docs/javascripts/nav-scroll-fix.js new file mode 100644 index 00000000..3b295f8b --- /dev/null +++ b/docs/javascripts/nav-scroll-fix.js @@ -0,0 +1,141 @@ +/** + * Navigation Scroll Position Preservation + * + * This script preserves the sidebar scroll position when navigating between pages. + * Without this, clicking a link in the scrolled sidebar would reset it to the top. + */ + +(function() { + 'use strict'; + + const STORAGE_KEY = 'open_judge-sidebar-scroll'; + const SIDEBAR_SELECTORS = [ + '[data-slot="sidebar-content"]', // Current theme's sidebar container + '[data-sidebar="content"]', // Alternative selector + '.md-sidebar--primary', // MkDocs Material theme + 'nav.sidebar', + '.md-sidebar', + '.nav-sidebar', + 'aside.sidebar' + ]; + + /** + * Get the primary sidebar element + */ + function getSidebar() { + for (const selector of SIDEBAR_SELECTORS) { + const sidebar = document.querySelector(selector); + if (sidebar) { + return sidebar; + } + } + return null; + } + + /** + * Restore scroll position instantly without smooth scrolling flicker. + */ + function setScrollTopInstant(sidebar, position) { + if (!sidebar) return; + const originalBehavior = sidebar.style.scrollBehavior; + sidebar.style.scrollBehavior = 'auto'; + sidebar.scrollTop = position; + // Restore original behavior on next frame to keep smooth scrolling elsewhere. + requestAnimationFrame(() => { + if (originalBehavior) { + sidebar.style.scrollBehavior = originalBehavior; + } else { + sidebar.style.removeProperty('scroll-behavior'); + } + }); + } + + /** + * Save sidebar scroll position to sessionStorage + */ + function saveSidebarScroll() { + const sidebar = getSidebar(); + if (sidebar) { + try { + const scrollData = { + position: sidebar.scrollTop, + timestamp: Date.now() + }; + sessionStorage.setItem(STORAGE_KEY, JSON.stringify(scrollData)); + } catch (e) { + console.warn('Failed to save sidebar scroll position:', e); + } + } + } + + /** + * Restore sidebar scroll position from sessionStorage + */ + function restoreSidebarScroll() { + const sidebar = getSidebar(); + if (!sidebar) return; + + try { + const stored = sessionStorage.getItem(STORAGE_KEY); + if (stored) { + const scrollData = JSON.parse(stored); + + // Only restore if saved within the last 5 minutes + const age = Date.now() - scrollData.timestamp; + if (age < 5 * 60 * 1000) { + // Use requestAnimationFrame to ensure DOM is ready + requestAnimationFrame(() => { + setScrollTopInstant(sidebar, scrollData.position); + }); + } else { + // Clear old data + sessionStorage.removeItem(STORAGE_KEY); + } + } + } catch (e) { + console.warn('Failed to restore sidebar scroll position:', e); + } + } + + /** + * Initialize scroll position preservation + */ + function init() { + // Restore scroll position on page load + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', restoreSidebarScroll); + } else { + restoreSidebarScroll(); + } + + // Save scroll position before navigation + window.addEventListener('beforeunload', saveSidebarScroll); + + // Save scroll position when clicking sidebar links + document.addEventListener('click', function(e) { + const link = e.target.closest('a'); + if (!link) return; + + // Check if the link is inside the sidebar + const sidebar = getSidebar(); + if (sidebar && sidebar.contains(link)) { + // Save current scroll position + saveSidebarScroll(); + } + }); + + // Periodically save scroll position while user scrolls + const sidebar = getSidebar(); + if (sidebar) { + let scrollTimeout; + sidebar.addEventListener('scroll', function() { + clearTimeout(scrollTimeout); + scrollTimeout = setTimeout(saveSidebarScroll, 150); + }); + } + } + + // Initialize when script loads + init(); + +})(); diff --git a/docs/javascripts/responsive.js b/docs/javascripts/responsive.js new file mode 100644 index 00000000..d57c4db2 --- /dev/null +++ b/docs/javascripts/responsive.js @@ -0,0 +1,355 @@ +/** + * Responsive Enhancements for OpenJudge Documentation + * Phase 5: 响应式完善 + * + * Features: + * - Mobile menu toggle + * - Table scroll detection + * - Touch event optimization + * - Viewport resize handling + */ + +(function() { + 'use strict'; + + // ======================================== + // Mobile Navigation Toggle + // ======================================== + + function initMobileNav() { + const sidebar = document.querySelector('.md-sidebar--primary, nav.sidebar, .nav-sidebar'); + const menuToggle = document.querySelector('.mobile-menu-toggle'); + let overlay = document.querySelector('.mobile-nav-overlay'); + + // Create overlay if it doesn't exist + if (!overlay && sidebar) { + overlay = document.createElement('div'); + overlay.className = 'mobile-nav-overlay'; + document.body.appendChild(overlay); + } + + // Create menu toggle if it doesn't exist + if (!menuToggle && sidebar) { + const toggle = document.createElement('button'); + toggle.className = 'mobile-menu-toggle'; + toggle.setAttribute('aria-label', 'Toggle navigation menu'); + toggle.innerHTML = ` + + + + + + + + + + `; + document.body.appendChild(toggle); + + toggle.addEventListener('click', toggleMobileNav); + } + + if (overlay) { + overlay.addEventListener('click', closeMobileNav); + } + + // Close on escape key + document.addEventListener('keydown', function(e) { + if (e.key === 'Escape') { + closeMobileNav(); + } + }); + + // Close on resize to desktop + window.addEventListener('resize', debounce(function() { + if (window.innerWidth >= 768) { + closeMobileNav(); + } + }, 100)); + } + + function toggleMobileNav() { + const sidebar = document.querySelector('.md-sidebar--primary, nav.sidebar, .nav-sidebar'); + const menuToggle = document.querySelector('.mobile-menu-toggle'); + const overlay = document.querySelector('.mobile-nav-overlay'); + + if (sidebar) { + sidebar.classList.toggle('open'); + } + if (menuToggle) { + menuToggle.classList.toggle('active'); + } + if (overlay) { + overlay.classList.toggle('visible'); + } + + // Prevent body scroll when menu is open + document.body.classList.toggle('nav-open'); + } + + function closeMobileNav() { + const sidebar = document.querySelector('.md-sidebar--primary, nav.sidebar, .nav-sidebar'); + const menuToggle = document.querySelector('.mobile-menu-toggle'); + const overlay = document.querySelector('.mobile-nav-overlay'); + + if (sidebar) { + sidebar.classList.remove('open'); + } + if (menuToggle) { + menuToggle.classList.remove('active'); + } + if (overlay) { + overlay.classList.remove('visible'); + } + document.body.classList.remove('nav-open'); + } + + // ======================================== + // Table Scroll Detection + // ======================================== + + function initTableScroll() { + const tables = document.querySelectorAll('.table-responsive, table'); + + tables.forEach(function(table) { + let wrapper = table; + + // Wrap table if not already in a responsive container + if (!table.classList.contains('table-responsive') && table.tagName === 'TABLE') { + wrapper = document.createElement('div'); + wrapper.className = 'table-responsive'; + table.parentNode.insertBefore(wrapper, table); + wrapper.appendChild(table); + } + + // Check scroll state + updateTableScrollState(wrapper); + + // Listen for scroll + wrapper.addEventListener('scroll', function() { + updateTableScrollState(wrapper); + }); + }); + + // Update on resize + window.addEventListener('resize', debounce(function() { + document.querySelectorAll('.table-responsive').forEach(updateTableScrollState); + }, 100)); + } + + function updateTableScrollState(wrapper) { + if (!wrapper) return; + + const scrollLeft = wrapper.scrollLeft; + const scrollWidth = wrapper.scrollWidth; + const clientWidth = wrapper.clientWidth; + + // Check if table is scrollable + const canScroll = scrollWidth > clientWidth; + + // Update classes + wrapper.classList.toggle('can-scroll', canScroll); + wrapper.classList.toggle('can-scroll-left', scrollLeft > 0); + wrapper.classList.toggle('can-scroll-right', scrollLeft < scrollWidth - clientWidth - 1); + } + + // ======================================== + // Touch Event Optimization + // ======================================== + + function initTouchOptimization() { + // Detect touch device + const isTouchDevice = 'ontouchstart' in window || navigator.maxTouchPoints > 0; + + if (isTouchDevice) { + document.body.classList.add('touch-device'); + + // Fast tap for navigation links + const navLinks = document.querySelectorAll('.md-nav__link, nav a'); + navLinks.forEach(function(link) { + link.addEventListener('touchend', function(e) { + // Prevent double-tap zoom on navigation + if (e.target.tagName === 'A') { + e.preventDefault(); + window.location.href = e.target.href; + } + }); + }); + } else { + document.body.classList.add('pointer-device'); + } + } + + // ======================================== + // Viewport Height Fix (Mobile Safari) + // ======================================== + + function initViewportFix() { + // Fix for mobile viewport height (100vh issue) + function setViewportHeight() { + const vh = window.innerHeight * 0.01; + document.documentElement.style.setProperty('--vh', `${vh}px`); + } + + setViewportHeight(); + + window.addEventListener('resize', debounce(setViewportHeight, 100)); + window.addEventListener('orientationchange', function() { + setTimeout(setViewportHeight, 100); + }); + } + + // ======================================== + // Scroll Progress Indicator + // ======================================== + + function initScrollProgress() { + let progressBar = document.querySelector('.scroll-progress'); + + // Create progress bar if it doesn't exist + if (!progressBar) { + progressBar = document.createElement('div'); + progressBar.className = 'scroll-progress'; + document.body.prepend(progressBar); + } + + function updateProgress() { + const scrollTop = window.scrollY || document.documentElement.scrollTop; + const scrollHeight = document.documentElement.scrollHeight - window.innerHeight; + const progress = scrollHeight > 0 ? (scrollTop / scrollHeight) * 100 : 0; + + progressBar.style.width = `${progress}%`; + } + + window.addEventListener('scroll', throttle(updateProgress, 10)); + updateProgress(); + } + + // ======================================== + // Scroll to Top Button + // ======================================== + + function initScrollToTop() { + let scrollBtn = document.querySelector('.scroll-to-top'); + + // Create button if it doesn't exist + if (!scrollBtn) { + scrollBtn = document.createElement('button'); + scrollBtn.className = 'scroll-to-top'; + scrollBtn.setAttribute('aria-label', 'Scroll to top'); + scrollBtn.innerHTML = ` + + + + `; + document.body.appendChild(scrollBtn); + } + + function toggleButton() { + const scrollTop = window.scrollY || document.documentElement.scrollTop; + scrollBtn.classList.toggle('visible', scrollTop > 300); + } + + scrollBtn.addEventListener('click', function() { + window.scrollTo({ + top: 0, + behavior: 'smooth' + }); + }); + + window.addEventListener('scroll', throttle(toggleButton, 100)); + toggleButton(); + } + + // ======================================== + // Responsive Image Loading + // ======================================== + + function initResponsiveImages() { + // Lazy load images + const images = document.querySelectorAll('img[loading="lazy"]'); + + if ('IntersectionObserver' in window) { + const imageObserver = new IntersectionObserver(function(entries) { + entries.forEach(function(entry) { + if (entry.isIntersecting) { + const img = entry.target; + img.classList.add('loaded'); + imageObserver.unobserve(img); + } + }); + }, { + rootMargin: '50px 0px' + }); + + images.forEach(function(img) { + imageObserver.observe(img); + }); + } else { + // Fallback for browsers without IntersectionObserver + images.forEach(function(img) { + img.classList.add('loaded'); + }); + } + } + + // ======================================== + // Utility Functions + // ======================================== + + function debounce(func, wait) { + let timeout; + return function executedFunction() { + const context = this; + const args = arguments; + clearTimeout(timeout); + timeout = setTimeout(function() { + func.apply(context, args); + }, wait); + }; + } + + function throttle(func, limit) { + let inThrottle; + return function() { + const context = this; + const args = arguments; + if (!inThrottle) { + func.apply(context, args); + inThrottle = true; + setTimeout(function() { + inThrottle = false; + }, limit); + } + }; + } + + // ======================================== + // Initialize All + // ======================================== + + function init() { + initMobileNav(); + initTableScroll(); + initTouchOptimization(); + initViewportFix(); + initScrollProgress(); + initScrollToTop(); + initResponsiveImages(); + } + + // Run on DOM ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', init); + } else { + init(); + } + + // Expose functions for external use + window.OpenJudgeResponsive = { + toggleMobileNav: toggleMobileNav, + closeMobileNav: closeMobileNav, + updateTableScrollState: updateTableScrollState + }; + +})(); diff --git a/docs/javascripts/search-fix.js b/docs/javascripts/search-fix.js new file mode 100644 index 00000000..444f2af9 --- /dev/null +++ b/docs/javascripts/search-fix.js @@ -0,0 +1 @@ +/* Search fix - placeholder */ diff --git a/docs/javascripts/tabbed-code.js b/docs/javascripts/tabbed-code.js new file mode 100644 index 00000000..cfd19559 --- /dev/null +++ b/docs/javascripts/tabbed-code.js @@ -0,0 +1,176 @@ +/** + * Tabbed Code Blocks - JavaScript Enhancement for shadcn/ui theme + * Provides fallback functionality for pymdownx.tabbed alternate_style + */ + +(function() { + 'use strict'; + + function initTabbedSets() { + // Find all tabbed sets + const tabbedSets = document.querySelectorAll('.tabbed-set.tabbed-alternate'); + + tabbedSets.forEach(function(tabbedSet) { + const inputs = tabbedSet.querySelectorAll(':scope > input[type="radio"]'); + const labels = tabbedSet.querySelectorAll(':scope > .tabbed-labels > label'); + const blocks = tabbedSet.querySelectorAll(':scope > .tabbed-content > .tabbed-block'); + + // Function to update active state + function updateActiveState() { + let activeIndex = 0; + + // Find which input is checked + inputs.forEach(function(input, index) { + if (input.checked) { + activeIndex = index; + } + }); + + // Update labels + labels.forEach(function(label, index) { + if (index === activeIndex) { + label.classList.add('tabbed-label--active'); + label.setAttribute('data-active', 'true'); + } else { + label.classList.remove('tabbed-label--active'); + label.setAttribute('data-active', 'false'); + } + }); + + // Update content blocks + blocks.forEach(function(block, index) { + if (index === activeIndex) { + block.style.display = 'block'; + block.classList.add('tabbed-block--active'); + } else { + block.style.display = 'none'; + block.classList.remove('tabbed-block--active'); + } + }); + } + + // Listen for changes on radio inputs + inputs.forEach(function(input) { + input.addEventListener('change', updateActiveState); + }); + + // Also handle label clicks directly (backup for CSS label-for behavior) + labels.forEach(function(label, index) { + label.addEventListener('click', function(e) { + if (inputs[index]) { + inputs[index].checked = true; + // Trigger change event + inputs[index].dispatchEvent(new Event('change')); + } + }); + }); + + // Initialize state + updateActiveState(); + + // Add copy button to tabbed code blocks + addCopyButtonToTabbedSet(tabbedSet); + }); + } + + function addCopyButtonToTabbedSet(tabbedSet) { + // Check if copy button already exists + if (tabbedSet.querySelector('.copy-button')) { + return; + } + + // Find the labels container + const labelsContainer = tabbedSet.querySelector('.tabbed-labels'); + if (!labelsContainer) return; + + // Create copy button + const copyButton = document.createElement('button'); + copyButton.className = 'copy-button'; + copyButton.innerHTML = ''; + copyButton.setAttribute('aria-label', 'Copy code'); + copyButton.setAttribute('title', 'Copy code'); + + // Add click handler + copyButton.addEventListener('click', function(e) { + e.preventDefault(); + e.stopPropagation(); + + // Find the active code block + const activeBlock = tabbedSet.querySelector('.tabbed-block--active'); + if (!activeBlock) return; + + // Get code content + const codeElement = activeBlock.querySelector('pre code') || activeBlock.querySelector('code'); + if (!codeElement) return; + + // Copy to clipboard + const code = codeElement.textContent; + + // Use modern clipboard API + if (navigator.clipboard && window.isSecureContext) { + navigator.clipboard.writeText(code).then(function() { + showCopiedState(copyButton); + }).catch(function(err) { + console.error('Failed to copy:', err); + fallbackCopyTextToClipboard(code, copyButton); + }); + } else { + // Fallback for older browsers + fallbackCopyTextToClipboard(code, copyButton); + } + }); + + // Insert copy button into labels container + labelsContainer.appendChild(copyButton); + } + + function fallbackCopyTextToClipboard(text, button) { + const textArea = document.createElement('textarea'); + textArea.value = text; + textArea.style.position = 'fixed'; + textArea.style.left = '-999999px'; + textArea.style.top = '-999999px'; + document.body.appendChild(textArea); + textArea.focus(); + textArea.select(); + + try { + const successful = document.execCommand('copy'); + if (successful) { + showCopiedState(button); + } + } catch (err) { + console.error('Fallback: Failed to copy', err); + } + + document.body.removeChild(textArea); + } + + function showCopiedState(button) { + const originalHTML = button.innerHTML; + button.innerHTML = ''; + button.classList.add('copied'); + + setTimeout(function() { + button.innerHTML = originalHTML; + button.classList.remove('copied'); + }, 2000); + } + + // Run on DOM ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', initTabbedSets); + } else { + initTabbedSets(); + } + + // Also re-initialize on navigation (for SPA-like behavior) + if (typeof document$ !== 'undefined') { + document$.subscribe(function() { + initTabbedSets(); + }); + } + + // Export for manual re-initialization if needed + window.initTabbedSets = initTabbedSets; +})(); diff --git a/docs/logo.png b/docs/logo.png new file mode 100644 index 00000000..038c78bc Binary files /dev/null and b/docs/logo.png differ diff --git a/docs/logo.svg b/docs/logo.svg new file mode 100644 index 00000000..f4b49456 --- /dev/null +++ b/docs/logo.svg @@ -0,0 +1 @@ + diff --git a/docs/mini_test.py b/docs/mini_test.py new file mode 100644 index 00000000..d6c70073 --- /dev/null +++ b/docs/mini_test.py @@ -0,0 +1,12 @@ +from ajet import AgentJetJob +from tutorial.example_math_agent.math_agent_simplify import MathToolWorkflow + +model_path = "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-1___5B-Instruct" +job = AgentJetJob(backbone="trinity", n_gpu=2, n_gpu_for_infer=1, algorithm="grpo", model=model_path) +job.set_workflow(MathToolWorkflow, ensure_reward_in_workflow=True) +job.set_data(type="hf", dataset_path="openai/gsm8k") +# [Optional] job.dump_job_as_yaml('./saved_experiments/math.yaml') # Save yaml file for manual adjustment +# [Optional] job.load_job_from_yaml('./saved_experiments/math.yaml') # Load yaml file from manual adjustment + +# Equivalent to `ajet --conf ./saved_experiments/math.yaml` in the terminal +tuned_model = job.tune() diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..db4f637c --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,13 @@ +# MkDocs and theme +mkdocs==1.6.1 +mkdocs-shadcn==0.9.5 + +# Plugins +mkdocstrings==0.30.1 +mkdocstrings-python==1.18.2 + +# Markdown extensions (pymdownx is included in pymdown-extensions) +pymdown-extensions==10.16.1 + +# Syntax highlighting +Pygments>=2.18.0 diff --git a/docs/stylesheets/animations.css b/docs/stylesheets/animations.css new file mode 100644 index 00000000..9d390ff7 --- /dev/null +++ b/docs/stylesheets/animations.css @@ -0,0 +1,877 @@ +/* + * Animations & Visual Enhancements for OpenJudge Documentation + * Phase 3: 视觉增强 + * + * Features: + * - Keyframe Animations (fadeIn, slideUp, shimmer, pulse, etc.) + * - Page Load Effects + * - Hover Interactions + * - Transition Effects + * - Visual Polish (shadows, gradients, etc.) + */ + +/* ======================================== + CSS Variables - Animation System + ======================================== */ + +:root { + /* Animation Timing */ + --rm-transition-fast: 0.15s; + --rm-transition-normal: 0.25s; + --rm-transition-slow: 0.4s; + --rm-ease-smooth: cubic-bezier(0.4, 0, 0.2, 1); + --rm-ease-in: cubic-bezier(0.4, 0, 1, 1); + --rm-ease-out: cubic-bezier(0, 0, 0.2, 1); + --rm-ease-in-out: cubic-bezier(0.4, 0, 0.2, 1); + + /* Shadow System */ + --rm-shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05); + --rm-shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); + --rm-shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); + --rm-shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04); + + /* Hover Lift */ + --rm-lift-sm: translateY(-2px); + --rm-lift-md: translateY(-4px); +} + +/* Dark mode shadows */ +.dark, +.dark { + --rm-shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.3); + --rm-shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.4), 0 2px 4px -1px rgba(0, 0, 0, 0.3); + --rm-shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.5), 0 4px 6px -2px rgba(0, 0, 0, 0.3); + --rm-shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.6), 0 10px 10px -5px rgba(0, 0, 0, 0.4); +} + +/* ======================================== + Keyframe Animations + ======================================== */ + +/* Fade In */ +@keyframes fadeIn { + from { + opacity: 0; + } + to { + opacity: 1; + } +} + +/* Fade In Up */ +@keyframes fadeInUp { + from { + opacity: 0; + transform: translateY(20px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +/* Fade In Down */ +@keyframes fadeInDown { + from { + opacity: 0; + transform: translateY(-20px); + } + to { + opacity: 1; + transform: translateY(0); + } +} + +/* Slide Up */ +@keyframes slideUp { + from { + transform: translateY(10px); + opacity: 0; + } + to { + transform: translateY(0); + opacity: 1; + } +} + +/* Slide Down */ +@keyframes slideDown { + from { + transform: translateY(-10px); + opacity: 0; + } + to { + transform: translateY(0); + opacity: 1; + } +} + +/* Slide In From Left */ +@keyframes slideInLeft { + from { + transform: translateX(-20px); + opacity: 0; + } + to { + transform: translateX(0); + opacity: 1; + } +} + +/* Slide In From Right */ +@keyframes slideInRight { + from { + transform: translateX(20px); + opacity: 0; + } + to { + transform: translateX(0); + opacity: 1; + } +} + +/* Scale In */ +@keyframes scaleIn { + from { + transform: scale(0.95); + opacity: 0; + } + to { + transform: scale(1); + opacity: 1; + } +} + +/* Pulse */ +@keyframes pulse { + 0%, 100% { + opacity: 1; + } + 50% { + opacity: 0.7; + } +} + +/* Shimmer (Loading effect) */ +@keyframes shimmer { + 0% { + background-position: -1000px 0; + } + 100% { + background-position: 1000px 0; + } +} + +/* Spin */ +@keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } +} + +/* Bounce */ +@keyframes bounce { + 0%, 100% { + transform: translateY(0); + } + 50% { + transform: translateY(-10px); + } +} + +/* Wiggle */ +@keyframes wiggle { + 0%, 100% { + transform: rotate(0deg); + } + 25% { + transform: rotate(-3deg); + } + 75% { + transform: rotate(3deg); + } +} + +/* Expand (for collapsible sections) */ +@keyframes expand { + from { + max-height: 0; + opacity: 0; + } + to { + max-height: 2000px; + opacity: 1; + } +} + +/* Collapse */ +@keyframes collapse { + from { + max-height: 2000px; + opacity: 1; + } + to { + max-height: 0; + opacity: 0; + } +} + +/* Glow */ +@keyframes glow { + 0%, 100% { + box-shadow: 0 0 5px currentColor; + } + 50% { + box-shadow: 0 0 20px currentColor; + } +} + +/* ======================================== + Page Load Animations + ======================================== */ + +/* Main content fade in */ +article, +.md-content__inner, +main.md-main { + animation: fadeInUp 0.5s var(--rm-ease-out) forwards; +} + +/* Stagger animation for list items on load */ +article > *:nth-child(1) { animation-delay: 0ms; } +article > *:nth-child(2) { animation-delay: 50ms; } +article > *:nth-child(3) { animation-delay: 100ms; } +article > *:nth-child(4) { animation-delay: 150ms; } +article > *:nth-child(5) { animation-delay: 200ms; } + +/* Reduce motion for accessibility */ +@media (prefers-reduced-motion: reduce) { + *, + *::before, + *::after { + animation-duration: 0.01ms !important; + animation-iteration-count: 1 !important; + transition-duration: 0.01ms !important; + } + + article, + .md-content__inner, + main.md-main { + animation: none; + } +} + +/* ======================================== + Link Hover Effects + ======================================== */ + +article a:not(.button):not(.btn), +.prose a:not(.button):not(.btn), +.md-typeset a:not(.button):not(.btn) { + position: relative; + transition: color var(--rm-transition-fast) var(--rm-ease-smooth); +} + +/* Animated underline on hover */ +article a:not(.button):not(.btn)::after, +.prose a:not(.button):not(.btn)::after { + content: ''; + position: absolute; + left: 0; + bottom: -2px; + width: 0; + height: 1px; + background: currentColor; + transition: width var(--rm-transition-normal) var(--rm-ease-out); +} + +article a:not(.button):not(.btn):hover::after, +.prose a:not(.button):not(.btn):hover::after { + width: 100%; +} + +/* External link icon animation */ +article a[href^="http"]::after, +.prose a[href^="http"]::after { + display: inline-block; + transition: transform var(--rm-transition-fast) var(--rm-ease-smooth); +} + +article a[href^="http"]:hover::after, +.prose a[href^="http"]:hover::after { + transform: translate(2px, -2px); +} + +/* ======================================== + Button Hover Effects + ======================================== */ + +button, +.button, +.btn, +.md-button, +input[type="submit"], +input[type="button"] { + position: relative; + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); + cursor: pointer; +} + +button:hover, +.button:hover, +.btn:hover, +.md-button:hover { + transform: var(--rm-lift-sm); + box-shadow: var(--rm-shadow-md); +} + +button:active, +.button:active, +.btn:active, +.md-button:active { + transform: scale(0.98); + box-shadow: var(--rm-shadow-sm); +} + +/* Ripple effect on click */ +button::before, +.button::before, +.btn::before { + content: ''; + position: absolute; + top: 50%; + left: 50%; + width: 0; + height: 0; + border-radius: 50%; + background: rgba(255, 255, 255, 0.3); + transform: translate(-50%, -50%); + transition: width 0.6s, height 0.6s; +} + +button:active::before, +.button:active::before, +.btn:active::before { + width: 300px; + height: 300px; +} + +/* ======================================== + Code Block Hover Effects + ======================================== */ + +/* Code block container */ +.highlight, +.codehilite, +pre[class*="language-"], +div[class*="highlight-"] { + position: relative; + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); + box-shadow: var(--rm-shadow-sm); +} + +.highlight:hover, +.codehilite:hover, +pre[class*="language-"]:hover, +div[class*="highlight-"]:hover { + box-shadow: var(--rm-shadow-md); + transform: translateY(-1px); +} + +/* Copy button hover effect */ +.copy-button, +button[data-clipboard-target], +.md-clipboard { + opacity: 0; + transform: scale(0.9); + transition: all var(--rm-transition-fast) var(--rm-ease-smooth); +} + +.highlight:hover .copy-button, +.codehilite:hover .copy-button, +pre:hover .copy-button, +.highlight:hover .md-clipboard, +.codehilite:hover .md-clipboard, +pre:hover .md-clipboard { + opacity: 1; + transform: scale(1); +} + +.copy-button:hover, +.md-clipboard:hover { + transform: scale(1.1); + background: var(--primary, #3b82f6); + color: white; +} + +.copy-button:active, +.md-clipboard:active { + transform: scale(0.95); +} + +/* Copy success animation */ +.copy-button.copied, +.md-clipboard.copied { + animation: pulse 0.4s var(--rm-ease-smooth); +} + +/* Code language badge */ +.highlight > .language-name, +pre > .language-name, +.code-lang-badge { + position: absolute; + top: 0.5rem; + right: 0.5rem; + padding: 0.125rem 0.5rem; + font-size: 0.75rem; + font-weight: 500; + background: rgba(0, 0, 0, 0.6); + color: rgba(255, 255, 255, 0.9); + border-radius: 0.25rem; + text-transform: uppercase; + letter-spacing: 0.05em; + opacity: 0.7; + transition: opacity var(--rm-transition-fast) var(--rm-ease-smooth); + pointer-events: none; + backdrop-filter: blur(4px); +} + +.highlight:hover > .language-name, +pre:hover > .language-name, +.code-lang-badge:hover { + opacity: 1; +} + +/* ======================================== + Card & Container Effects + ======================================== */ + +/* Admonition hover effect */ +article .admonition, +.prose .admonition, +.md-typeset .admonition { + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); + box-shadow: var(--rm-shadow-sm); +} + +article .admonition:hover, +.prose .admonition:hover, +.md-typeset .admonition:hover { + box-shadow: var(--rm-shadow-md); + transform: translateY(-2px); +} + +/* Workflow steps hover */ +.workflow ol > li, +ol.workflow-steps > li { + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); +} + +.workflow ol > li:hover, +ol.workflow-steps > li:hover { + transform: translateX(4px); +} + +.workflow ol > li::before, +ol.workflow-steps > li::before { + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); +} + +.workflow ol > li:hover::before, +ol.workflow-steps > li:hover::before { + transform: scale(1.1); + box-shadow: var(--rm-shadow-md); +} + +/* ======================================== + Tab Switching Animations + ======================================== */ + +/* Tab content transition */ +.tabbed-block { + animation: fadeIn var(--rm-transition-normal) var(--rm-ease-smooth); +} + +.tabbed-block--active { + animation: slideDown var(--rm-transition-normal) var(--rm-ease-smooth); +} + +/* Tab label transition */ +.tabbed-labels > label, +.tabbed-set label { + position: relative; + transition: all var(--rm-transition-fast) var(--rm-ease-smooth); +} + +.tabbed-labels > label:hover, +.tabbed-set label:hover { + transform: translateY(-2px); +} + +/* Active tab indicator animation */ +.tabbed-labels > label::after, +.tabbed-set label::after { + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); +} + +/* ======================================== + Collapsible/Details Animation + ======================================== */ + +/* Details element smooth expand/collapse */ +article details, +.prose details, +.md-typeset details { + overflow: hidden; + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); +} + +article details[open], +.prose details[open], +.md-typeset details[open] { + animation: slideDown var(--rm-transition-normal) var(--rm-ease-smooth); +} + +article details summary, +.prose details summary, +.md-typeset details summary { + transition: all var(--rm-transition-fast) var(--rm-ease-smooth); +} + +article details summary:hover, +.prose details summary:hover, +.md-typeset details summary:hover { + background: var(--muted, #f3f4f6); + padding-left: 1.5rem; +} + +/* Arrow rotation animation is already in readability-enhancements.css */ + +/* ======================================== + Image & Media Effects + ======================================== */ + +/* Image lazy loading placeholder */ +img[loading="lazy"] { + background: linear-gradient( + 90deg, + var(--muted, #f3f4f6) 0%, + var(--muted-foreground, #e5e7eb) 50%, + var(--muted, #f3f4f6) 100% + ); + background-size: 200% 100%; + animation: shimmer 1.5s infinite; +} + +img[loading="lazy"].loaded { + animation: fadeIn 0.3s var(--rm-ease-out); + background: transparent; +} + +/* Image hover effect */ +article img, +.prose img, +.md-typeset img { + transition: all var(--rm-transition-normal) var(--rm-ease-smooth); +} + +article img:hover, +.prose img:hover, +.md-typeset img:hover { + transform: scale(1.02); + box-shadow: var(--rm-shadow-lg); +} + +/* Figure animation */ +article figure, +.prose figure, +.md-typeset figure { + animation: fadeInUp 0.6s var(--rm-ease-out); +} + +/* ======================================== + Table Hover Effects + ======================================== */ + +/* Table row hover (already in table-enhancements.css, just adding animation) */ +article table tbody tr, +.prose table tbody tr, +.md-typeset table tbody tr { + transition: all var(--rm-transition-fast) var(--rm-ease-smooth); +} + +/* ======================================== + Navigation & Sidebar Effects + ======================================== */ + +/* Sidebar items */ +.md-nav__item, +.md-nav__link, +nav li a { + transition: all var(--rm-transition-fast) var(--rm-ease-smooth); +} + +.md-nav__link:hover, +nav li a:hover { + transform: translateX(4px); + color: var(--primary, #3b82f6); +} + +/* Active nav item indicator */ +.md-nav__link--active, +nav li a.active, +nav li a[aria-current="page"] { + position: relative; +} + +.md-nav__link--active::before, +nav li a.active::before, +nav li a[aria-current="page"]::before { + content: ''; + position: absolute; + left: -1rem; + top: 50%; + transform: translateY(-50%); + width: 3px; + height: 70%; + background: var(--primary, #3b82f6); + border-radius: 2px; + animation: slideInLeft 0.3s var(--rm-ease-out); +} + +/* ======================================== + Loading States + ======================================== */ + +/* Skeleton loader */ +.skeleton { + background: linear-gradient( + 90deg, + var(--muted, #f3f4f6) 25%, + var(--muted-foreground, #e5e7eb) 50%, + var(--muted, #f3f4f6) 75% + ); + background-size: 200% 100%; + animation: shimmer 1.5s infinite; + border-radius: var(--radius, 0.375rem); +} + +/* Spinner */ +.spinner, +.loading-spinner { + display: inline-block; + width: 1em; + height: 1em; + border: 2px solid var(--muted, #e5e7eb); + border-top-color: var(--primary, #3b82f6); + border-radius: 50%; + animation: spin 0.8s linear infinite; +} + +/* ======================================== + Decorative Elements + ======================================== */ + +/* Gradient dividers */ +hr.gradient, +.divider-gradient { + height: 2px; + background: linear-gradient( + 90deg, + transparent 0%, + var(--primary, #3b82f6) 50%, + transparent 100% + ); + border: none; + margin: 3em 0; +} + +/* Animated gradient background (optional) */ +.hero-gradient, +.gradient-bg { + background: linear-gradient( + 135deg, + var(--primary, #3b82f6) 0%, + var(--primary-dark, #2563eb) 100% + ); + background-size: 200% 200%; + animation: gradientShift 8s ease infinite; +} + +@keyframes gradientShift { + 0%, 100% { + background-position: 0% 50%; + } + 50% { + background-position: 100% 50%; + } +} + +/* Glow effect for highlights */ +.glow, +.highlight-glow { + animation: glow 2s ease-in-out infinite; +} + +/* ======================================== + Scroll Animations + ======================================== */ + +/* Fade in elements on scroll (requires JS) */ +.fade-in-on-scroll { + opacity: 0; + transform: translateY(20px); + transition: opacity 0.6s var(--rm-ease-out), transform 0.6s var(--rm-ease-out); +} + +.fade-in-on-scroll.visible { + opacity: 1; + transform: translateY(0); +} + +/* Slide in from left on scroll */ +.slide-in-left { + opacity: 0; + transform: translateX(-40px); + transition: opacity 0.6s var(--rm-ease-out), transform 0.6s var(--rm-ease-out); +} + +.slide-in-left.visible { + opacity: 1; + transform: translateX(0); +} + +/* Slide in from right on scroll */ +.slide-in-right { + opacity: 0; + transform: translateX(40px); + transition: opacity 0.6s var(--rm-ease-out), transform 0.6s var(--rm-ease-out); +} + +.slide-in-right.visible { + opacity: 1; + transform: translateX(0); +} + +/* ======================================== + Focus States (Accessibility) + ======================================== */ + +/* Enhance focus indicators with animation */ +a:focus-visible, +button:focus-visible, +input:focus-visible, +textarea:focus-visible, +select:focus-visible { + outline: 2px solid var(--primary, #3b82f6); + outline-offset: 2px; + animation: pulse 0.4s var(--rm-ease-smooth); +} + +/* ======================================== + Special Effects + ======================================== */ + +/* Confetti effect (for success states) */ +@keyframes confetti { + 0% { + transform: translateY(0) rotate(0deg); + opacity: 1; + } + 100% { + transform: translateY(100vh) rotate(720deg); + opacity: 0; + } +} + +/* Shake (for errors) */ +@keyframes shake { + 0%, 100% { + transform: translateX(0); + } + 10%, 30%, 50%, 70%, 90% { + transform: translateX(-4px); + } + 20%, 40%, 60%, 80% { + transform: translateX(4px); + } +} + +.shake { + animation: shake 0.4s var(--rm-ease-smooth); +} + +/* Bounce in (for notifications) */ +@keyframes bounceIn { + 0% { + opacity: 0; + transform: scale(0.3); + } + 50% { + opacity: 1; + transform: scale(1.05); + } + 70% { + transform: scale(0.9); + } + 100% { + transform: scale(1); + } +} + +.bounce-in { + animation: bounceIn 0.6s var(--rm-ease-out); +} + +/* ======================================== + Performance Optimizations + ======================================== */ + +/* Hardware acceleration for smooth animations */ +.highlight, +.codehilite, +.admonition, +button, +.button, +a, +img { + will-change: auto; + backface-visibility: hidden; + -webkit-backface-visibility: hidden; +} + +/* Prevent layout shifts during animations */ +* { + transform: translateZ(0); +} + +/* ======================================== + Utility Classes + ======================================== */ + +.animate-fadeIn { animation: fadeIn 0.3s var(--rm-ease-out); } +.animate-fadeInUp { animation: fadeInUp 0.5s var(--rm-ease-out); } +.animate-fadeInDown { animation: fadeInDown 0.5s var(--rm-ease-out); } +.animate-slideUp { animation: slideUp 0.3s var(--rm-ease-out); } +.animate-slideDown { animation: slideDown 0.3s var(--rm-ease-out); } +.animate-slideInLeft { animation: slideInLeft 0.4s var(--rm-ease-out); } +.animate-slideInRight { animation: slideInRight 0.4s var(--rm-ease-out); } +.animate-scaleIn { animation: scaleIn 0.3s var(--rm-ease-out); } +.animate-pulse { animation: pulse 2s infinite; } +.animate-spin { animation: spin 1s linear infinite; } +.animate-bounce { animation: bounce 1s infinite; } +.animate-wiggle { animation: wiggle 0.5s var(--rm-ease-smooth); } + +/* Delay utilities */ +.delay-100 { animation-delay: 100ms; } +.delay-200 { animation-delay: 200ms; } +.delay-300 { animation-delay: 300ms; } +.delay-500 { animation-delay: 500ms; } + +/* Duration utilities */ +.duration-fast { animation-duration: var(--rm-transition-fast); } +.duration-normal { animation-duration: var(--rm-transition-normal); } +.duration-slow { animation-duration: var(--rm-transition-slow); } diff --git a/docs/stylesheets/code-enhancements.css b/docs/stylesheets/code-enhancements.css new file mode 100644 index 00000000..462435b7 --- /dev/null +++ b/docs/stylesheets/code-enhancements.css @@ -0,0 +1,513 @@ +/* + * Code Enhancements for OpenJudge Documentation + * Phase 1: 代码块样式增强 + * + * Features: + * - 行内代码样式优化 + * - 代码块圆角和边框 + * - 代码块标题栏样式 + * - 行号样式优化 + * - 代码复制按钮样式 + * - 语法高亮微调 + * - 长代码横向滚动指示 + */ + +/* ======================================== + Inline Code Styling + ======================================== */ + +article code:not(pre code), +.prose code:not(pre code), +.md-typeset code:not(pre code) { + font-family: 'JetBrains Mono', ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace; + font-size: 0.875em; + font-weight: 450; + padding: 0.2em 0.4em; + margin: 0 0.1em; + background: var(--muted, #f3f4f6); + border: 1px solid var(--border, #e5e7eb); + border-radius: 0.375rem; + color: var(--foreground, #1f2937); + word-break: break-word; + -webkit-font-smoothing: antialiased; +} + +/* Inline code in links - inherit link color */ +article a code:not(pre code), +.prose a code:not(pre code), +.md-typeset a code:not(pre code) { + color: inherit; + background: transparent; + border: 1px solid; + border-color: color-mix(in srgb, currentColor 30%, transparent); +} + +/* Fallback for browsers without color-mix support */ +@supports not (border-color: color-mix(in srgb, currentColor 30%, transparent)) { + article a code:not(pre code), + .prose a code:not(pre code), + .md-typeset a code:not(pre code) { + border-color: currentColor; + opacity: 0.8; + } +} + +/* Inline code in headings */ +article h1 code, article h2 code, article h3 code, +article h4 code, article h5 code, article h6 code, +.prose h1 code, .prose h2 code, .prose h3 code, +.prose h4 code, .prose h5 code, .prose h6 code { + font-size: 0.9em; +} + +/* ======================================== + Code Block Container + ======================================== */ + +article pre, +.prose pre, +.md-typeset pre { + margin: 1.5em 0; + padding: 0; + /* Fallback for browsers without OKLCH support */ + background: #ffffff; + background: var(--background, #ffffff); + border: 1px solid #e5e7eb; + border: 1px solid var(--border, #e5e7eb); + border-radius: var(--radius-lg, 0.5rem); + overflow: visible; + position: relative; +} + +/* Code inside pre */ +article pre code, +.prose pre code, +.md-typeset pre code { + display: block; + padding: 1rem 1.25rem; + overflow-x: auto; + overflow-y: auto; + max-height: 600px; + font-family: 'JetBrains Mono', ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace; + font-size: 0.8125rem; + line-height: 1.7; + background: transparent; + border: none; + border-radius: var(--radius-lg, 0.5rem); + -webkit-font-smoothing: antialiased; + tab-size: 2; +} + +/* ======================================== + Code Block with Title + ======================================== */ + +/* Title bar for code blocks (when using title="filename.py") */ +article .highlight .filename, +.prose .highlight .filename, +.md-typeset .highlight .filename { + display: block; + padding: 0.5rem 1rem; + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 0.75rem; + font-weight: 500; + color: #6b7280; + color: var(--muted-foreground, #6b7280); + /* Fallback for browsers without OKLCH support */ + background: #ffffff; + background: var(--background, #ffffff); + border-bottom: 1px solid #e5e7eb; + border-bottom: 1px solid var(--border, #e5e7eb); + user-select: none; +} + +/* Code block with data-title attribute */ +article pre[data-title]::before, +.prose pre[data-title]::before, +.md-typeset pre[data-title]::before { + content: attr(data-title); + display: block; + padding: 0.5rem 1rem; + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 0.75rem; + font-weight: 500; + color: #6b7280; + color: var(--muted-foreground, #6b7280); + /* Fallback for browsers without OKLCH support */ + background: #ffffff; + background: var(--background, #ffffff); + border-bottom: 1px solid #e5e7eb; + border-bottom: 1px solid var(--border, #e5e7eb); + user-select: none; +} + +/* ======================================== + Syntax Highlighting Wrapper + ======================================== */ + +article .highlight, +.prose .highlight, +.md-typeset .highlight { + margin: 1.5em 0; + border-radius: var(--radius-lg, 0.5rem); + overflow: visible; + /* Fallback for browsers without OKLCH support */ + border: 1px solid #e5e7eb; + border: 1px solid var(--border, #e5e7eb); + background: #ffffff; + background: var(--background, #ffffff); +} + +article .highlight pre, +.prose .highlight pre, +.md-typeset .highlight pre { + margin: 0; + border: none; + border-radius: var(--radius-lg, 0.5rem); + overflow: visible; +} + +/* ======================================== + Line Numbers + ======================================== */ + +/* Line number gutter */ +article .highlight .linenos, +article .highlight .linenodiv, +.prose .highlight .linenos, +.prose .highlight .linenodiv, +.md-typeset .highlight .linenos, +.md-typeset .highlight .linenodiv { + padding: 1rem 0; + padding-right: 1rem; + padding-left: 0.75rem; + text-align: right; + color: var(--muted-foreground, #9ca3af); + background: rgba(0, 0, 0, 0.02); + border-right: none; + user-select: none; + font-size: 0.75rem; + line-height: 1.7; +} + +/* Individual line numbers */ +article .highlight .linenos span, +article .highlight .linenodiv pre span, +.prose .highlight .linenos span, +.md-typeset .highlight .linenos span { + display: block; + line-height: 1.7; +} + +/* Highlighted line */ +article .highlight .hll, +.prose .highlight .hll, +.md-typeset .highlight .hll { + background: rgba(255, 213, 0, 0.15); + display: block; + margin: 0 -1.25rem; + padding: 0 1.25rem; +} + +/* ======================================== + Code Copy Button + ======================================== */ + +article .highlight .copy-button, +article pre .copy-button, +.prose .highlight .copy-button, +.md-typeset .highlight .copy-button, +button.copy-code-button, +.code-copy-btn { + position: absolute; + top: 0.5rem; + right: 0.5rem; + padding: 0.375rem 0.5rem; + font-size: 0.75rem; + font-weight: 500; + color: #6b7280; + color: var(--muted-foreground, #6b7280); + /* Fallback for browsers without OKLCH support */ + background: #ffffff; + background: var(--background, #fff); + border: 1px solid #e5e7eb; + border: 1px solid var(--border, #e5e7eb); + border-radius: 0.375rem; + cursor: pointer; + opacity: 0; + transition: all 0.15s ease; + z-index: 10; +} + +article .highlight:hover .copy-button, +article pre:hover .copy-button, +.prose .highlight:hover .copy-button, +.md-typeset .highlight:hover .copy-button, +.highlight:hover button.copy-code-button, +pre:hover .code-copy-btn { + opacity: 1; +} + +article .highlight .copy-button:hover, +article pre .copy-button:hover, +.prose .highlight .copy-button:hover, +.md-typeset .highlight .copy-button:hover, +button.copy-code-button:hover, +.code-copy-btn:hover { + color: var(--foreground, #1f2937); + background: var(--muted, #f3f4f6); + border-color: var(--border, #d1d5db); +} + +/* Copy button success state */ +article .highlight .copy-button.copied, +article pre .copy-button.copied, +button.copy-code-button.copied, +.code-copy-btn.copied { + color: var(--success, #10b981); + border-color: var(--success, #10b981); +} + +/* ======================================== + Language Label + ======================================== */ + +article .highlight[data-lang]::before, +.prose .highlight[data-lang]::before, +.md-typeset .highlight[data-lang]::before { + content: attr(data-lang); + position: absolute; + top: 0.5rem; + right: 3.5rem; + font-family: 'JetBrains Mono', monospace; + font-size: 0.625rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--muted-foreground, #9ca3af); + opacity: 0.7; + pointer-events: none; +} + +/* ======================================== + Scroll Indicator (Horizontal & Vertical) + ======================================== */ + +article pre code, +.prose pre code, +.md-typeset pre code { + scrollbar-width: thin; + scrollbar-color: var(--muted-foreground, #9ca3af) transparent; +} + +/* Horizontal scrollbar */ +article pre code::-webkit-scrollbar, +.prose pre code::-webkit-scrollbar, +.md-typeset pre code::-webkit-scrollbar { + height: 6px; + width: 6px; +} + +article pre code::-webkit-scrollbar-track, +.prose pre code::-webkit-scrollbar-track, +.md-typeset pre code::-webkit-scrollbar-track { + background: transparent; +} + +article pre code::-webkit-scrollbar-thumb, +.prose pre code::-webkit-scrollbar-thumb, +.md-typeset pre code::-webkit-scrollbar-thumb { + background: var(--muted-foreground, #d1d5db); + border-radius: 3px; +} + +article pre code::-webkit-scrollbar-thumb:hover, +.prose pre code::-webkit-scrollbar-thumb:hover, +.md-typeset pre code::-webkit-scrollbar-thumb:hover { + background: var(--foreground, #9ca3af); +} + +/* Scrollbar corner (when both scrollbars are present) */ +article pre code::-webkit-scrollbar-corner, +.prose pre code::-webkit-scrollbar-corner, +.md-typeset pre code::-webkit-scrollbar-corner { + background: transparent; +} + +/* ======================================== + Dark Mode + ======================================== */ + +/* Dark mode inline code */ +.dark article code:not(pre code), +.dark .prose code:not(pre code), +.dark .md-typeset code:not(pre code), +.dark article code:not(pre code), +.dark .prose code:not(pre code), +.dark .md-typeset code:not(pre code) { + background: var(--muted, #1f2937); + border-color: var(--border, #374151); + color: var(--foreground, #e5e7eb); +} + +/* Dark mode code blocks */ +.dark article pre, +.dark .prose pre, +.dark .md-typeset pre, +.dark article .highlight, +.dark .prose .highlight, +.dark .md-typeset .highlight { + /* Fallback for browsers without OKLCH support */ + background: #0a0a0a; + background: var(--background, #0a0a0a); + border-color: #374151; + border-color: var(--border, #374151); +} + +/* Dark mode code color is handled by syntax-highlight.css */ + +/* Dark mode title bar */ +.dark article .highlight .filename, +.dark .prose .highlight .filename, +.dark .md-typeset .highlight .filename { + /* Fallback for browsers without OKLCH support */ + background: #0a0a0a; + background: var(--background, #0a0a0a); + border-bottom-color: #374151; + border-bottom-color: var(--border, #374151); + color: #9ca3af; + color: var(--muted-foreground, #9ca3af); +} + +/* Dark mode line numbers */ +.dark article .highlight .linenos, +.dark article .highlight .linenodiv, +.dark .prose .highlight .linenos, +.dark .md-typeset .highlight .linenos, +.dark article .highlight .linenos, +.dark article .highlight .linenodiv, +.dark .prose .highlight .linenos, +.dark .md-typeset .highlight .linenos { + background: rgba(255, 255, 255, 0.02); + border-right: none; + color: var(--muted-foreground, #6b7280); +} + +/* Dark mode highlighted line */ +.dark article .highlight .hll, +.dark .prose .highlight .hll, +.dark .md-typeset .highlight .hll, +.dark article .highlight .hll, +.dark .prose .highlight .hll, +.dark .md-typeset .highlight .hll { + background: rgba(255, 213, 0, 0.1); +} + +/* Dark mode copy button */ +.dark article .highlight .copy-button, +.dark article pre .copy-button, +.dark button.copy-code-button, +.dark .code-copy-btn { + /* Fallback for browsers without OKLCH support */ + background: #1f2937; + background: var(--background, #1f2937); + border-color: #374151; + border-color: var(--border, #374151); + color: #9ca3af; + color: var(--muted-foreground, #9ca3af); +} + +.dark article .highlight .copy-button:hover, +.dark article pre .copy-button:hover, +.dark button.copy-code-button:hover, +.dark .code-copy-btn:hover { + /* Fallback for browsers without OKLCH support */ + background: #374151; + background: var(--muted, #374151); + color: #e5e7eb; + color: var(--foreground, #e5e7eb); +} + +/* Dark mode scrollbar */ +.dark article pre code::-webkit-scrollbar-thumb, +.dark .prose pre code::-webkit-scrollbar-thumb, +.dark .md-typeset pre code::-webkit-scrollbar-thumb, +.dark article pre code::-webkit-scrollbar-thumb, +.dark .prose pre code::-webkit-scrollbar-thumb, +.dark .md-typeset pre code::-webkit-scrollbar-thumb { + background: var(--muted-foreground, #4b5563); +} + +/* ======================================== + Responsive + ======================================== */ + +@media (max-width: 640px) { + article code:not(pre code), + .prose code:not(pre code), + .md-typeset code:not(pre code) { + font-size: 0.8125em; + padding: 0.15em 0.35em; + } + + article pre code, + .prose pre code, + .md-typeset pre code { + padding: 0.875rem 1rem; + font-size: 0.75rem; + max-height: 400px; + } + + article .highlight .copy-button, + article pre .copy-button, + button.copy-code-button, + .code-copy-btn { + opacity: 1; + padding: 0.25rem 0.375rem; + font-size: 0.6875rem; + } +} + +/* ======================================== + Special Code Block Styles + ======================================== */ + +/* Terminal/Shell style */ +article pre.terminal code, +article .highlight.terminal pre code, +.prose pre.terminal code { + color: #22c55e; +} + +article pre.terminal code::before, +article .highlight.terminal pre code::before, +.prose pre.terminal code::before { + content: '$ '; + color: #9ca3af; + user-select: none; +} + +/* Output style (muted) */ +article pre.output code, +article .highlight.output pre code, +.prose pre.output code { + color: var(--muted-foreground, #6b7280); + font-style: italic; +} + +/* Diff style enhancements */ +article .highlight .gi, +.prose .highlight .gi, +.md-typeset .highlight .gi { + background: rgba(34, 197, 94, 0.15); + display: inline-block; + width: 100%; +} + +article .highlight .gd, +.prose .highlight .gd, +.md-typeset .highlight .gd { + background: rgba(239, 68, 68, 0.15); + display: inline-block; + width: 100%; +} diff --git a/docs/stylesheets/feature-cards.css b/docs/stylesheets/feature-cards.css new file mode 100644 index 00000000..5865ca73 --- /dev/null +++ b/docs/stylesheets/feature-cards.css @@ -0,0 +1,542 @@ +/* Feature Cards Styles */ +/* Supports dark mode and hover effects */ + +/* Card Container */ +.card-grid { + display: flex; + flex-wrap: wrap; + gap: 20px; + margin: 1rem 0; +} + +.card-grid-2 { + composes: card-grid; +} + +.card-grid-3 { + composes: card-grid; +} + +/* Base Card Style */ +.feature-card { + flex: 1 1 45%; + min-width: 280px; + text-decoration: none; + color: inherit; + border: 1px solid var(--md-default-fg-color--lightest, #e0e0e0); + border-radius: 12px; + padding: 20px; + transition: all 0.25s ease; + background: var(--md-default-bg-color, #fff); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.04); + cursor: pointer; + display: block; +} + +.feature-card:hover { + transform: translateY(-3px); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.1); + border-color: var(--md-primary-fg-color, #4051b5); + text-decoration: none; + color: inherit; +} + +/* Three column cards */ +.feature-card-sm { + flex: 1 1 30%; + min-width: 250px; + text-decoration: none; + color: inherit; + border: 1px solid var(--md-default-fg-color--lightest, #e0e0e0); + border-radius: 12px; + padding: 20px; + transition: all 0.25s ease; + background: var(--md-default-bg-color, #fff); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.04); + cursor: pointer; + display: block; +} + +.feature-card-sm:hover { + transform: translateY(-3px); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.1); + border-color: var(--md-primary-fg-color, #4051b5); + text-decoration: none; + color: inherit; +} + +/* Work in Progress Card */ +.feature-card-wip { + flex: 1 1 30%; + min-width: 250px; + text-decoration: none; + color: inherit; + border: 1px dashed var(--md-default-fg-color--light, #b0b0b0); + border-radius: 12px; + padding: 20px; + transition: all 0.25s ease; + background: var(--md-default-bg-color--light, #fafafa); + box-shadow: none; + opacity: 0.65; + pointer-events: none; + cursor: default; +} + +.feature-card-wip:hover { + transform: none; + box-shadow: none; + text-decoration: none; + color: inherit; +} + +/* Card Header */ +.card-header { + display: inline-flex !important; + align-items: center !important; + flex-wrap: nowrap !important; + margin-bottom: 12px; + white-space: nowrap; + pointer-events: none; +} + +.card-header h3 { + margin: 0 !important; + font-size: 16px; + font-weight: 600; + white-space: nowrap !important; + display: inline !important; + pointer-events: none; +} + +.card-header-lg h3 { + font-size: 18px; +} + +/* Card Icon */ +.card-icon { + height: 1.3em; + width: 1.3em; + min-width: 1.3em; + margin-right: 10px; + opacity: 0.9; + transition: all 0.25s ease; + flex-shrink: 0; + pointer-events: none; +} + +.feature-card:hover .card-icon, +.feature-card-sm:hover .card-icon, +.feature-card-wip:hover .card-icon { + opacity: 1; + transform: scale(1.1); +} + +/* Icon Colors by Category */ +.card-icon-agent { + filter: invert(45%) sepia(80%) saturate(500%) hue-rotate(190deg) brightness(95%); +} + +.card-icon-general { + filter: invert(50%) sepia(60%) saturate(400%) hue-rotate(100deg) brightness(95%); +} + +.card-icon-multimodal { + filter: invert(40%) sepia(70%) saturate(500%) hue-rotate(250deg) brightness(95%); +} + +.card-icon-math { + filter: invert(55%) sepia(70%) saturate(500%) hue-rotate(10deg) brightness(95%); +} + +.card-icon-tool { + filter: invert(45%) sepia(60%) saturate(400%) hue-rotate(170deg) brightness(95%); +} + +.card-icon-data { + filter: invert(50%) sepia(60%) saturate(450%) hue-rotate(130deg) brightness(95%); +} + +.card-icon-integration { + filter: invert(45%) sepia(70%) saturate(450%) hue-rotate(220deg) brightness(95%); +} + +/* Card Description */ +.card-desc { + margin: 0; + font-size: 13px; + opacity: 0.8; + line-height: 1.6; + pointer-events: none; +} + +.card-desc-lg { + font-size: 14px; +} + +/* Make all children non-interactive so clicks pass through to the link */ +.feature-card *, +.feature-card-sm * { + pointer-events: none; +} + +/* Badge for Work in Progress */ +.badge-wip { + font-size: 12px; + background-color: var(--md-warning-fg-color--light, #fff3cd); + color: var(--md-warning-fg-color, #856404); + padding: 2px 10px; + border-radius: 12px; + margin-left: 10px; + font-weight: 500; +} + +/* Callout Tip - Highlighted intro section */ +.callout-tip { + background: linear-gradient(135deg, rgba(245, 158, 11, 0.04) 0%, rgba(245, 158, 11, 0.01) 100%); + border: 1px solid rgba(245, 158, 11, 0.1); + border-left: 3px solid rgba(245, 158, 11, 0.5); + border-radius: 10px; + padding: 18px 22px; + margin: 1.5rem 0; + position: relative; +} + +.callout-tip p { + margin: 0; + line-height: 1.7; + font-size: 15px; +} + +.callout-tip .callout-icon { + height: 1.3em; + width: 1.3em; + margin-right: 10px; + vertical-align: middle; + display: inline-block; + opacity: 0.6; + filter: invert(60%) sepia(50%) saturate(400%) hue-rotate(5deg) brightness(100%); + flex-shrink: 0; +} + +/* Dark Mode for Callout Tip */ +.dark .callout-tip { + background: linear-gradient(135deg, rgba(245, 158, 11, 0.08) 0%, rgba(245, 158, 11, 0.02) 100%); + border-color: rgba(245, 158, 11, 0.15); + border-left-color: rgba(245, 158, 11, 0.6); +} + +.dark .callout-tip .callout-icon { + filter: invert(75%) sepia(60%) saturate(500%) hue-rotate(5deg) brightness(110%); +} + +/* Key Features Section */ +.key-features { + background: var(--md-default-bg-color, #fff); + border-radius: 12px; + padding: 8px; + margin: 1rem 0; +} + +.key-features ul { + margin: 0; + padding-left: 0; + list-style: none; +} + +.key-features > ul > li { + margin-bottom: 12px; + padding: 16px 20px; + border-radius: 10px; + border: 1px solid transparent; + transition: all 0.2s ease; +} + +.key-features > ul > li:last-child { + margin-bottom: 0; +} + +/* Feature 1: Library - Blue */ +.key-features > ul > li:nth-child(1) { + background: linear-gradient(135deg, rgba(59, 130, 246, 0.08) 0%, rgba(59, 130, 246, 0.02) 100%); + border-color: rgba(59, 130, 246, 0.12); +} + +.key-features > ul > li:nth-child(1):hover { + background: linear-gradient(135deg, rgba(59, 130, 246, 0.12) 0%, rgba(59, 130, 246, 0.04) 100%); + border-color: rgba(59, 130, 246, 0.18); +} + +/* Feature 2: Building - Green */ +.key-features > ul > li:nth-child(2) { + background: linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(16, 185, 129, 0.02) 100%); + border-color: rgba(16, 185, 129, 0.12); +} + +.key-features > ul > li:nth-child(2):hover { + background: linear-gradient(135deg, rgba(16, 185, 129, 0.12) 0%, rgba(16, 185, 129, 0.04) 100%); + border-color: rgba(16, 185, 129, 0.18); +} + +/* Feature 3: Integration - Purple */ +.key-features > ul > li:nth-child(3) { + background: linear-gradient(135deg, rgba(139, 92, 246, 0.08) 0%, rgba(139, 92, 246, 0.02) 100%); + border-color: rgba(139, 92, 246, 0.12); +} + +.key-features > ul > li:nth-child(3):hover { + background: linear-gradient(135deg, rgba(139, 92, 246, 0.12) 0%, rgba(139, 92, 246, 0.04) 100%); + border-color: rgba(139, 92, 246, 0.18); +} + +.key-features ul ul { + margin-top: 10px; + padding-left: 20px; +} + +.key-features ul ul li { + margin-bottom: 6px; + position: relative; + opacity: 0.85; +} + +/* Removed arrow decoration for cleaner appearance */ +/* .key-features ul ul li::before { + content: "›"; + position: absolute; + left: -16px; + font-weight: bold; + opacity: 0.5; +} + +.key-features > ul > li:nth-child(1) ul li::before { + color: #3b82f6; +} + +.key-features > ul > li:nth-child(2) ul li::before { + color: #10b981; +} + +.key-features > ul > li:nth-child(3) ul li::before { + color: #8b5cf6; +} */ + +/* Dark Mode Adjustments */ +.dark .feature-card, +.dark .feature-card-sm { + background: rgba(30, 30, 30, 0.6); + border-color: rgba(255, 255, 255, 0.1); +} + +.dark .feature-card:hover, +.dark .feature-card-sm:hover { + background: rgba(40, 40, 40, 0.8); + box-shadow: 0 8px 24px rgba(0, 0, 0, 0.4); + border-color: rgba(255, 255, 255, 0.2); +} + +.dark .feature-card-wip { + background: rgba(30, 30, 30, 0.4); + border-color: rgba(255, 255, 255, 0.08); +} + +.dark .key-features { + background: rgba(255, 255, 255, 0.02); +} + +/* Feature 1: Library - Blue (Dark Mode) */ +.dark .key-features > ul > li:nth-child(1) { + background: linear-gradient(135deg, rgba(59, 130, 246, 0.18) 0%, rgba(59, 130, 246, 0.05) 100%); + border: 1px solid rgba(59, 130, 246, 0.2); +} + +.dark .key-features > ul > li:nth-child(1):hover { + background: linear-gradient(135deg, rgba(59, 130, 246, 0.28) 0%, rgba(59, 130, 246, 0.08) 100%); + border-color: rgba(59, 130, 246, 0.35); +} + +/* Feature 2: Building - Green (Dark Mode) */ +.dark .key-features > ul > li:nth-child(2) { + background: linear-gradient(135deg, rgba(16, 185, 129, 0.18) 0%, rgba(16, 185, 129, 0.05) 100%); + border: 1px solid rgba(16, 185, 129, 0.2); +} + +.dark .key-features > ul > li:nth-child(2):hover { + background: linear-gradient(135deg, rgba(16, 185, 129, 0.28) 0%, rgba(16, 185, 129, 0.08) 100%); + border-color: rgba(16, 185, 129, 0.35); +} + +/* Feature 3: Integration - Purple (Dark Mode) */ +.dark .key-features > ul > li:nth-child(3) { + background: linear-gradient(135deg, rgba(139, 92, 246, 0.18) 0%, rgba(139, 92, 246, 0.05) 100%); + border: 1px solid rgba(139, 92, 246, 0.2); +} + +.dark .key-features > ul > li:nth-child(3):hover { + background: linear-gradient(135deg, rgba(139, 92, 246, 0.28) 0%, rgba(139, 92, 246, 0.08) 100%); + border-color: rgba(139, 92, 246, 0.35); +} + +.dark .badge-wip { + background-color: rgba(255, 193, 7, 0.25); + color: #ffc107; +} + +/* Ensure text readability in dark mode */ +.dark .key-features ul ul li { + opacity: 0.9; +} + +.dark .key-features strong, +.dark .key-features b { + color: rgba(255, 255, 255, 0.95); +} + +/* Key Features Sub-point Links */ +.key-features ul ul li a.feature-link { + display: inline-flex; + align-items: center; + gap: 4px; + font-size: 12px; + font-weight: 500; + text-decoration: none; + padding: 3px 10px; + border-radius: 4px; + margin-left: 8px; + transition: all 0.2s ease; + vertical-align: baseline; + pointer-events: auto; + white-space: nowrap; + line-height: 1.4; +} + +.key-features ul ul li a.feature-link .link-arrow { + font-size: 11px; + transition: transform 0.2s ease; + display: inline-block; +} + +.key-features ul ul li a.feature-link:hover .link-arrow { + transform: translateX(2px); +} + +/* Blue links for Feature 1 (Library) */ +.key-features > ul > li:nth-child(1) ul li a.feature-link { + color: #3b82f6; + background: rgba(59, 130, 246, 0.08); + border: 1px solid rgba(59, 130, 246, 0.15); +} + +.key-features > ul > li:nth-child(1) ul li a.feature-link:hover { + background: rgba(59, 130, 246, 0.15); + border-color: rgba(59, 130, 246, 0.3); + text-decoration: none; +} + +/* Green links for Feature 2 (Building) */ +.key-features > ul > li:nth-child(2) ul li a.feature-link { + color: #10b981; + background: rgba(16, 185, 129, 0.08); + border: 1px solid rgba(16, 185, 129, 0.15); +} + +.key-features > ul > li:nth-child(2) ul li a.feature-link:hover { + background: rgba(16, 185, 129, 0.15); + border-color: rgba(16, 185, 129, 0.3); + text-decoration: none; +} + +/* Purple links for Feature 3 (Integration) */ +.key-features > ul > li:nth-child(3) ul li a.feature-link { + color: #8b5cf6; + background: rgba(139, 92, 246, 0.08); + border: 1px solid rgba(139, 92, 246, 0.15); +} + +.key-features > ul > li:nth-child(3) ul li a.feature-link:hover { + background: rgba(139, 92, 246, 0.15); + border-color: rgba(139, 92, 246, 0.3); + text-decoration: none; +} + +/* Dark mode adjustments for feature links */ +.dark .key-features > ul > li:nth-child(1) ul li a.feature-link { + color: #60a5fa; + background: rgba(59, 130, 246, 0.15); + border-color: rgba(59, 130, 246, 0.25); +} + +.dark .key-features > ul > li:nth-child(1) ul li a.feature-link:hover { + background: rgba(59, 130, 246, 0.25); + border-color: rgba(59, 130, 246, 0.4); +} + +.dark .key-features > ul > li:nth-child(2) ul li a.feature-link { + color: #34d399; + background: rgba(16, 185, 129, 0.15); + border-color: rgba(16, 185, 129, 0.25); +} + +.dark .key-features > ul > li:nth-child(2) ul li a.feature-link:hover { + background: rgba(16, 185, 129, 0.25); + border-color: rgba(16, 185, 129, 0.4); +} + +.dark .key-features > ul > li:nth-child(3) ul li a.feature-link { + color: #a78bfa; + background: rgba(139, 92, 246, 0.15); + border-color: rgba(139, 92, 246, 0.25); +} + +.dark .key-features > ul > li:nth-child(3) ul li a.feature-link:hover { + background: rgba(139, 92, 246, 0.25); + border-color: rgba(139, 92, 246, 0.4); +} + +/* Card with arrow indicator */ +.feature-card::after, +.feature-card-sm::after { + content: "→"; + position: absolute; + right: 16px; + bottom: 16px; + opacity: 0; + transition: all 0.25s ease; + color: var(--md-primary-fg-color, #4051b5); +} + +.feature-card:hover::after, +.feature-card-sm:hover::after { + opacity: 0.6; + right: 12px; +} + +.feature-card, +.feature-card-sm, +.feature-card-wip { + position: relative; +} + +/* Inline Icons for text content (replacing emoji) */ +.inline-icon { + height: 1em; + width: 1em; + vertical-align: -0.125em; + display: inline; + opacity: 0.85; + filter: var(--inline-icon-filter, none); +} + +/* Dark mode filter for inline icons */ +.dark .inline-icon { + filter: invert(1) hue-rotate(180deg); + opacity: 0.9; +} + +/* Alternative: use CSS variables for color control */ +:root { + --inline-icon-filter: none; +} + +.dark { + --inline-icon-filter: invert(1) hue-rotate(180deg); +} diff --git a/docs/stylesheets/flowchart.css b/docs/stylesheets/flowchart.css new file mode 100644 index 00000000..345b94f1 --- /dev/null +++ b/docs/stylesheets/flowchart.css @@ -0,0 +1,402 @@ +/* Flowchart Component Styling */ +/* Modern, card-based flowchart with dark mode support */ + +/* ======================================== + Flowchart Container + ======================================== */ + +.flowchart-container { + margin: 2rem 0; + padding: 0; + display: flex; + flex-direction: row; + align-items: center; + justify-content: center; + gap: 0; + width: 100%; + overflow-x: auto; +} + +/* ======================================== + Flowchart Box Styles + ======================================== */ + +.flowchart-box { + position: relative; + padding: 1.5rem 1.75rem; + margin: 0; + width: 24rem; + min-width: 22rem; + flex-shrink: 0; + border-radius: 0.75rem; + border: 1px solid #e5e7eb; + background: linear-gradient(135deg, #ffffff 0%, #fafbfc 100%); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.04), 0 1px 2px rgba(0, 0, 0, 0.06); + transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1); + overflow: visible; +} + +.flowchart-box:hover { + border-color: #d1d5db; + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08), 0 2px 4px rgba(0, 0, 0, 0.06); + transform: translateY(-2px); +} + +/* Dark mode */ +.dark .flowchart-box { + border-color: #374151; + background: linear-gradient(135deg, #1f2937 0%, #1a2332 100%); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.3), 0 1px 2px rgba(0, 0, 0, 0.4); +} + +.dark .flowchart-box:hover { + border-color: #4b5563; + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.5), 0 2px 4px rgba(0, 0, 0, 0.4); +} + +/* ======================================== + Box Header (Title) + ======================================== */ + +.flowchart-box-header { + font-size: 1rem; + font-weight: 700; + color: #111827; + margin-bottom: 1rem; + padding-bottom: 0.75rem; + border-bottom: 2px solid #e5e7eb; + display: flex; + align-items: center; + gap: 0.625rem; + white-space: nowrap; + overflow: visible; +} + +.dark .flowchart-box-header { + color: #f3f4f6; + border-bottom-color: #374151; +} + +/* Header icon */ +.flowchart-box-header::before { + content: ""; + display: inline-block; + width: 1.25rem; + height: 1.25rem; + flex-shrink: 0; + background-color: #6b7280; + mask-repeat: no-repeat; + mask-position: center; + mask-size: contain; +} + +.dark .flowchart-box-header::before { + background-color: #9ca3af; +} + +/* Input box icon */ +.flowchart-box.input .flowchart-box-header::before { + mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='20' height='20' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpath d='M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z'%3E%3C/path%3E%3Cpolyline points='14 2 14 8 20 8'%3E%3C/polyline%3E%3C/svg%3E"); + background-color: #3b82f6; +} + +.dark .flowchart-box.input .flowchart-box-header::before { + background-color: #60a5fa; +} + +/* Grader box icon */ +.flowchart-box.grader .flowchart-box-header::before { + mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='20' height='20' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpolyline points='20 6 9 17 4 12'%3E%3C/polyline%3E%3C/svg%3E"); + background-color: #10b981; +} + +.dark .flowchart-box.grader .flowchart-box-header::before { + background-color: #34d399; +} + +/* Output box icon */ +.flowchart-box.output .flowchart-box-header::before { + mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='20' height='20' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpath d='M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4'%3E%3C/path%3E%3Cpolyline points='7 10 12 15 17 10'%3E%3C/polyline%3E%3Cline x1='12' y1='15' x2='12' y2='3'%3E%3C/line%3E%3C/svg%3E"); + background-color: #8b5cf6; +} + +.dark .flowchart-box.output .flowchart-box-header::before { + background-color: #a78bfa; +} + +/* ======================================== + Box Content (List Items) + ======================================== */ + +.flowchart-box-content { + margin: 0; + padding: 0; + list-style: none; + overflow: visible; +} + +.flowchart-box-content li { + position: relative; + padding-left: 1.75rem; + margin-bottom: 0.625rem; + font-size: 0.9375rem; + line-height: 1.6; + color: #4b5563; + word-wrap: break-word; + overflow-wrap: break-word; +} + +.flowchart-box-content li:last-child { + margin-bottom: 0; +} + +.dark .flowchart-box-content li { + color: #d1d5db; +} + +/* List item bullet */ +.flowchart-box-content li::before { + content: ""; + position: absolute; + left: 0; + top: 0.5rem; + width: 0.375rem; + height: 0.375rem; + background: #9ca3af; + border-radius: 50%; +} + +.dark .flowchart-box-content li::before { + background: #6b7280; +} + +/* Label styling (e.g., "Query", "Response") */ +.flowchart-box-content li strong { + font-weight: 600; + color: #1f2937; +} + +.dark .flowchart-box-content li strong { + color: #f9fafb; +} + +/* Tag styling (e.g., "(optional)", "(required)") */ +.flowchart-box-content li em { + font-style: normal; + font-size: 0.8125rem; + font-weight: 500; + padding: 0.125rem 0.5rem; + margin-left: 0.5rem; + border-radius: 0.25rem; + background: rgba(59, 130, 246, 0.1); + color: #3b82f6; +} + +.flowchart-box-content li em.optional { + background: rgba(107, 114, 128, 0.1); + color: #6b7280; +} + +.dark .flowchart-box-content li em { + background: rgba(96, 165, 250, 0.15); + color: #60a5fa; +} + +.dark .flowchart-box-content li em.optional { + background: rgba(156, 163, 175, 0.15); + color: #9ca3af; +} + +/* Nested list for sub-items */ +.flowchart-box-content ul { + margin: 0.5rem 0 0 0; + padding-left: 1.25rem; + list-style: none; + overflow: visible; +} + +.flowchart-box-content ul li { + font-size: 0.875rem; + color: #6b7280; + padding-left: 1.5rem; + word-wrap: break-word; + overflow-wrap: break-word; +} + +.dark .flowchart-box-content ul li { + color: #9ca3af; +} + +.flowchart-box-content ul li::before { + width: 0.25rem; + height: 0.25rem; + top: 0.5rem; +} + +/* ======================================== + Arrow Connector + ======================================== */ + +.flowchart-arrow { + position: relative; + display: flex; + justify-content: center; + align-items: center; + width: 3rem; + height: auto; + margin: 0; + flex-shrink: 0; +} + +.flowchart-arrow::before { + content: ""; + width: 100%; + height: 2px; + background: linear-gradient(90deg, #d1d5db 0%, #9ca3af 50%, #d1d5db 100%); + position: absolute; + top: 50%; + transform: translateY(-50%); +} + +.dark .flowchart-arrow::before { + background: linear-gradient(90deg, #4b5563 0%, #6b7280 50%, #4b5563 100%); +} + +/* Arrow icon */ +.flowchart-arrow::after { + content: ""; + width: 1.5rem; + height: 1.5rem; + background-color: #6b7280; + mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2.5' stroke-linecap='round' stroke-linejoin='round'%3E%3Cline x1='5' y1='12' x2='19' y2='12'%3E%3C/line%3E%3Cpolyline points='12 5 19 12 12 19'%3E%3C/polyline%3E%3C/svg%3E"); + mask-repeat: no-repeat; + mask-position: center; + mask-size: contain; + position: relative; + animation: arrow-bounce-horizontal 2s ease-in-out infinite; +} + +.dark .flowchart-arrow::after { + background-color: #9ca3af; +} + +/* Arrow bounce animation (horizontal) */ +@keyframes arrow-bounce-horizontal { + 0%, 100% { + transform: translateX(0); + } + 50% { + transform: translateX(4px); + } +} + +/* ======================================== + Responsive Design + ======================================== */ + +@media (max-width: 1024px) { + .flowchart-container { + flex-direction: column; + gap: 0; + } + + .flowchart-box { + width: 100%; + max-width: 42rem; + margin: 0 auto; + } + + .flowchart-arrow { + width: auto; + height: 2.5rem; + margin: 0 auto; + } + + .flowchart-arrow::before { + width: 2px; + height: 100%; + background: linear-gradient(180deg, #d1d5db 0%, #9ca3af 50%, #d1d5db 100%); + top: 0; + left: 50%; + transform: translateX(-50%); + } + + .dark .flowchart-arrow::before { + background: linear-gradient(180deg, #4b5563 0%, #6b7280 50%, #4b5563 100%); + } + + .flowchart-arrow::after { + mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='24' height='24' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2.5' stroke-linecap='round' stroke-linejoin='round'%3E%3Cline x1='12' y1='5' x2='12' y2='19'%3E%3C/line%3E%3Cpolyline points='19 12 12 19 5 12'%3E%3C/polyline%3E%3C/svg%3E"); + animation: arrow-bounce 2s ease-in-out infinite; + } + + @keyframes arrow-bounce { + 0%, 100% { + transform: translateY(0); + } + 50% { + transform: translateY(4px); + } + } +} + +@media (max-width: 640px) { + .flowchart-box { + padding: 1.25rem 1.5rem; + } + + .flowchart-box-header { + font-size: 0.9375rem; + } + + .flowchart-box-content li { + font-size: 0.875rem; + } + + .flowchart-arrow { + height: 2rem; + } +} + +/* ======================================== + Special Variants + ======================================== */ + +/* Highlighted box variant */ +.flowchart-box.highlight { + border-color: #10b981; + background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%); + box-shadow: 0 2px 8px rgba(16, 185, 129, 0.15), 0 1px 2px rgba(16, 185, 129, 0.1); +} + +.flowchart-box.highlight:hover { + border-color: #059669; + box-shadow: 0 4px 12px rgba(16, 185, 129, 0.2), 0 2px 4px rgba(16, 185, 129, 0.15); +} + +.dark .flowchart-box.highlight { + border-color: #34d399; + background: linear-gradient(135deg, #064e3b 0%, #065f46 100%); + box-shadow: 0 2px 8px rgba(52, 211, 153, 0.3), 0 1px 2px rgba(52, 211, 153, 0.2); +} + +.dark .flowchart-box.highlight:hover { + border-color: #10b981; + box-shadow: 0 4px 12px rgba(52, 211, 153, 0.4), 0 2px 4px rgba(52, 211, 153, 0.3); +} + +/* Compact variant */ +.flowchart-box.compact { + padding: 1rem 1.25rem; +} + +.flowchart-box.compact .flowchart-box-header { + font-size: 0.9375rem; + margin-bottom: 0.75rem; + padding-bottom: 0.5rem; +} + +.flowchart-box.compact .flowchart-box-content li { + font-size: 0.875rem; + margin-bottom: 0.5rem; +} diff --git a/docs/stylesheets/jupyter-simple.css b/docs/stylesheets/jupyter-simple.css new file mode 100644 index 00000000..864c59bd --- /dev/null +++ b/docs/stylesheets/jupyter-simple.css @@ -0,0 +1,258 @@ +/* Jupyter notebook presentation styles for OpenJudge docs */ +/* Applies to common nbconvert / mkdocs-jupyter markup */ + +article .jupyter-notebook, +article .nb-notebook, +article .jp-Notebook, +article .notebook { + display: block; + margin: 2.5rem 0; + gap: 1.5rem; +} + +article .jupyter-notebook .cell, +article .nb-notebook .cell, +article .jp-Notebook .jp-Cell, +article .notebook .cell, +article .jupyter-cell, +article .nb-cell { + position: relative; + margin: 1.75rem 0; + border: 1px solid var(--border, rgba(148, 163, 184, 0.35)); + border-radius: var(--radius-lg, 0.75rem); + background: var(--card, #ffffff); + overflow: hidden; + box-shadow: 0 18px 38px rgba(15, 23, 42, 0.08); +} + +article .jupyter-notebook .cell:first-of-type, +article .nb-notebook .cell:first-of-type, +article .notebook .cell:first-of-type { + margin-top: 0; +} + +/* Cell header ribbon */ +/* Support both data-type attribute and standard nbconvert class names */ +article .jupyter-notebook .cell::before, +article .nb-notebook .cell::before, +article .jp-Notebook .jp-Cell::before, +article .jupyter-cell::before, +article .nb-cell::before, +article .cell.code_cell::before, +article .cell.text_cell::before, +article .cell.markdown::before { + content: ""; + position: absolute; + top: 0.85rem; + left: 1rem; + padding: 0.15rem 0.55rem; + font-size: 0.75rem; + font-weight: 600; + line-height: 1rem; + letter-spacing: 0.04em; + text-transform: uppercase; + color: var(--muted-foreground, #4b5563); + background: rgba(148, 163, 184, 0.16); + border-radius: 999px; +} + +/* Markdown cells - via data-type or class */ +article .jupyter-notebook .cell[data-type="markdown"]::before, +article .nb-notebook .cell[data-type="markdown"]::before, +article .jupyter-cell[data-type="markdown"]::before, +article .cell.text_cell::before, +article .cell.markdown::before { + content: "Markdown"; +} + +/* Code cells - via data-type or class */ +article .jupyter-notebook .cell[data-type="code"]::before, +article .nb-notebook .cell[data-type="code"]::before, +article .jupyter-cell[data-type="code"]::before, +article .cell.code_cell::before { + content: "Code"; + color: var(--primary-foreground, #0f172a); + background: rgba(14, 165, 233, 0.15); +} + +/* Input (code) area */ +article .cell .input, +article .cell .input_area, +article .jupyter-cell .input, +article .nbinput, +article .jp-InputArea { + display: grid; + grid-template-columns: minmax(3.5rem, auto) minmax(0, 1fr); + gap: 0.5rem 1rem; + padding: 1.5rem 1.75rem 1.1rem; + background: var(--muted, rgba(15, 23, 42, 0.04)); + border-bottom: 1px solid var(--border, rgba(148, 163, 184, 0.28)); +} + +article .cell .input pre, +article .nbinput pre, +article .jp-InputArea pre { + margin: 0; + border-radius: var(--radius-md, 0.5rem); + background: transparent; + /* Inherit syntax highlighting background from Pygments or theme */ +} + +article .input_prompt, +article .prompt, +article .nbinput .prompt, +article .jp-InputArea-prompt { + font-family: var(--font-mono, "JetBrains Mono", "Fira Code", monospace); + font-size: 0.75rem; + font-weight: 600; + letter-spacing: 0.02em; + color: var(--primary, #0ea5e9); + padding-top: 0.25rem; +} + +article .input_prompt::after, +article .prompt.input_prompt::after, +article .jp-InputArea-prompt::after { + content: " ➜"; + opacity: 0.6; +} + +/* Output area */ +article .cell .output_wrapper, +article .cell .output, +article .nboutput, +article .jp-OutputArea { + display: block; + padding: 1.35rem 1.75rem; + background: var(--card, #ffffff); +} + +article .nboutput .prompt, +article .jp-OutputArea-prompt { + font-family: var(--font-mono, "JetBrains Mono", monospace); + font-size: 0.75rem; + font-weight: 600; + color: var(--primary, #0ea5e9); + opacity: 0.75; + margin-bottom: 0.75rem; +} + +article .nboutput pre, +article .jp-OutputArea pre { + background: rgba(15, 23, 42, 0.05); + border-radius: var(--radius-md, 0.5rem); + padding: 1rem 1.25rem; + margin: 0; +} + +article .nboutput table, +article .jp-OutputArea table { + width: 100%; + margin: 0.5rem 0 0; + border-collapse: collapse; + font-size: 0.875rem; +} + +article .nboutput table th, +article .jp-OutputArea table th, +article .nboutput table td, +article .jp-OutputArea table td { + border: 1px solid rgba(148, 163, 184, 0.25); + padding: 0.5rem 0.75rem; + text-align: left; +} + +/* Error outputs */ +article .nboutput.error, +article .jp-OutputArea[data-mime-type*="error"], +article .cell .output.stderr { + border-left: 3px solid #ef4444; + background: rgba(248, 113, 113, 0.12); + color: #991b1b; +} + +.dark article .nboutput.error, +.dark article .jp-OutputArea[data-mime-type*="error"], +.dark article .cell .output.stderr, +.dark article .nboutput.error, +.dark article .jp-OutputArea[data-mime-type*="error"], +.dark article .cell .output.stderr { + border-left-color: #fca5a5; + background: rgba(248, 113, 113, 0.21); + color: #fecaca; +} + +/* Markdown cells */ +article .cell.markdown, +article .cell.text_cell, +article .jupyter-cell[data-type="markdown"] { + padding: 2.25rem 2.5rem; + background: linear-gradient(135deg, rgba(59, 130, 246, 0.06), transparent); +} + +article .cell.markdown p:last-child, +article .cell.markdown ul:last-child, +article .cell.markdown ol:last-child { + margin-bottom: 0; +} + +/* Dark theme tuning */ +.dark article .jupyter-notebook .cell, +.dark article .nb-notebook .cell, +.dark article .jp-Notebook .jp-Cell, +.dark article .jupyter-notebook .cell, +.dark article .nb-notebook .cell, +.dark article .jp-Notebook .jp-Cell { + background: rgba(15, 23, 42, 0.75); + border-color: rgba(148, 163, 184, 0.22); + box-shadow: 0 20px 40px rgba(2, 6, 23, 0.65); +} + +.dark article .cell .input, +.dark article .cell .input, +.dark article .nbinput, +.dark article .nbinput { + background: rgba(148, 163, 184, 0.08); + border-bottom-color: rgba(148, 163, 184, 0.2); +} + +.dark article .nboutput pre, +.dark article .nboutput pre, +.dark article .jp-OutputArea pre, +.dark article .jp-OutputArea pre { + background: rgba(148, 163, 184, 0.12); +} + +.dark article .cell.markdown, +.dark article .cell.markdown, +.dark article .jupyter-cell[data-type="markdown"], +.dark article .jupyter-cell[data-type="markdown"] { + background: linear-gradient(135deg, rgba(14, 165, 233, 0.18), transparent); +} + +/* Responsive tweaks */ +@media (max-width: 768px) { + article .cell .input, + article .nbinput, + article .jp-InputArea { + grid-template-columns: minmax(0, 1fr); + padding: 1.25rem 1.25rem 0.9rem; + } + + article .cell .input pre, + article .nbinput pre, + article .jp-InputArea pre { + font-size: 0.85rem; + } + + article .cell .output_wrapper, + article .nboutput, + article .jp-OutputArea { + padding: 1.15rem 1.25rem; + } + + article .cell::before { + left: 1.25rem; + top: 0.75rem; + } +} diff --git a/docs/stylesheets/mermaid.css b/docs/stylesheets/mermaid.css new file mode 100644 index 00000000..f6425914 --- /dev/null +++ b/docs/stylesheets/mermaid.css @@ -0,0 +1,108 @@ +/* Mermaid diagram styling for OpenJudge docs */ +/* Aligns diagrams, adds padding, and keeps them readable across themes */ + +article .mermaid { + position: relative; + display: block; + margin: 2rem auto; + padding: 1.5rem; + border: 1px solid var(--border, rgba(148, 163, 184, 0.4)); + border-radius: var(--radius-lg, 0.75rem); + background: var(--card, #ffffff); + box-shadow: 0 12px 24px rgba(15, 23, 42, 0.06); + overflow-x: auto; + overflow-y: hidden; + text-align: center; + max-width: min(100%, 68rem); + scrollbar-width: thin; +} + +article .mermaid::-webkit-scrollbar { + height: 8px; +} + +article .mermaid::-webkit-scrollbar-thumb { + border-radius: 999px; + background: rgba(148, 163, 184, 0.45); +} + +article .mermaid::-webkit-scrollbar-track { + background: transparent; +} + +article .mermaid svg, +article .mermaid > svg { + display: inline-block; + width: auto; + max-width: none; + color: inherit; +} + +article .mermaid text { + font-family: var(--font-sans, "Inter", "Manrope", -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif); + font-size: 0.95rem; + fill: var(--foreground, #0f172a); +} + +article .mermaid .label { + color: var(--foreground, #0f172a); +} + +article .mermaid a { + color: var(--primary-foreground, var(--primary, #0284c7)); +} + +article figure.mermaid { + margin: 2rem auto; +} + +article figure.mermaid figcaption { + margin-top: 1rem; + font-size: 0.875rem; + color: var(--muted-foreground, #475569); + text-align: center; +} + +/* Dark theme adjustments */ +.dark article .mermaid, +.dark article .mermaid { + border-color: rgba(148, 163, 184, 0.2); + background: rgba(15, 23, 42, 0.55); + box-shadow: 0 12px 28px rgba(15, 23, 42, 0.65); +} + +.dark article .mermaid text, +.dark article .mermaid text { + fill: var(--muted-foreground, #e2e8f0); +} + +.dark article .mermaid .label, +.dark article .mermaid .label { + color: var(--muted-foreground, #e2e8f0); +} + +.dark article figure.mermaid figcaption, +.dark article figure.mermaid figcaption { + color: var(--muted-foreground, #cbd5f5); +} + +/* Hide raw fenced code only when Mermaid successfully renders */ +/* If .mermaid exists as a sibling, the diagram rendered successfully */ +article .mermaid ~ pre code.language-mermaid, +article .mermaid + pre code.language-mermaid { + display: none; +} + +/* Alternative: if Mermaid wraps the pre, hide the pre entirely */ +article .mermaid pre { + display: none; +} + +/* Responsive tweaks */ +@media (max-width: 640px) { + article .mermaid { + margin: 1.5rem -1rem; + padding: 1.25rem; + border-radius: var(--radius-md, 0.5rem); + } +} diff --git a/docs/stylesheets/mkdocstrings.css b/docs/stylesheets/mkdocstrings.css new file mode 100644 index 00000000..77b042c6 --- /dev/null +++ b/docs/stylesheets/mkdocstrings.css @@ -0,0 +1,325 @@ +/* mkdocstrings API documentation styling */ +/* Brings structured cards, signatures, and definition lists inline with OpenJudge visuals */ +/* + * IMPORTANT: These selectors target the default mkdocstrings-python template classes: + * - .doc, .doc-object, .doc-heading, .doc-signature, .doc-contents + * + * If you've customized your mkdocstrings template or upgraded to a version with + * different class names, you may need to adjust these selectors. + * + * Test with: mkdocs build && check generated API pages for matching classes + */ + +article .mkdocstrings, +article .doc.doc-object { + display: block; + margin: 2.5rem 0; +} + +article .mkdocstrings .doc, +article .doc.doc-object { + position: relative; + margin: 2.25rem 0; + border: 1px solid var(--border, rgba(148, 163, 184, 0.32)); + border-radius: var(--radius-xl, 1rem); + background: var(--card, #ffffff); + box-shadow: 0 24px 44px rgba(15, 23, 42, 0.08); + overflow: hidden; +} + +article .mkdocstrings .doc .doc-heading, +article .doc.doc-object .doc-heading { + display: flex; + align-items: baseline; + justify-content: space-between; + gap: 1rem; + padding: 1.75rem 2rem 1.25rem; + background: linear-gradient(135deg, rgba(14, 165, 233, 0.12), transparent); + border-bottom: 1px solid rgba(148, 163, 184, 0.2); +} + +article .mkdocstrings .doc .doc-heading h2, +article .mkdocstrings .doc .doc-heading h3, +article .doc.doc-object .doc-heading h2, +article .doc.doc-object .doc-heading h3 { + margin: 0; + font-size: clamp(1.25rem, 2.5vw, 1.75rem); + font-weight: 700; + color: var(--foreground, #0f172a); +} + +article .mkdocstrings .doc .doc-heading .doc-link, +article .doc.doc-object .doc-heading .doc-link { + font-size: 0.85rem; + font-weight: 500; + color: var(--primary, #0284c7); + text-decoration: none; + opacity: 0.85; +} + +article .mkdocstrings .doc .doc-heading .doc-link:hover, +article .doc.doc-object .doc-heading .doc-link:hover { + text-decoration: underline; + opacity: 1; +} + +/* Object signature */ +article .mkdocstrings .doc .doc-signature, +article .doc.doc-object .doc-signature, +article .mkdocstrings .doc pre.docstring-signature, +article .doc.doc-object pre.docstring-signature { + margin: 0; + padding: 1.25rem 2rem; + background: rgba(15, 23, 42, 0.05); + border-bottom: 1px solid rgba(148, 163, 184, 0.18); + overflow-x: auto; + font-family: var(--font-mono, "JetBrains Mono", "Fira Code", monospace); + font-size: 0.9rem; +} + +article .mkdocstrings .doc .doc-signature code, +article .doc.doc-object .doc-signature code, +article .mkdocstrings .doc pre.docstring-signature code, +article .doc.doc-object pre.docstring-signature code { + background: transparent; + padding: 0; + font-size: inherit; +} + +/* Docstring content */ +article .mkdocstrings .doc .doc-contents, +article .doc.doc-object .doc-contents { + padding: 1.75rem 2rem 2.25rem; + display: grid; + gap: 1.75rem; +} + +article .mkdocstrings .doc .doc-contents > p:first-child { + font-size: 1rem; + line-height: 1.8; + color: var(--muted-foreground, #475569); +} + +/* Definition lists (Parameters, Returns, etc.) */ +article .mkdocstrings dl, +article .doc.doc-object dl { + margin: 0; + padding: 1.25rem 1.5rem; + border: 1px solid rgba(148, 163, 184, 0.25); + border-radius: var(--radius-lg, 0.75rem); + background: rgba(148, 163, 184, 0.08); + display: grid; + gap: 0.85rem; +} + +article .mkdocstrings dl dt, +article .doc.doc-object dl dt { + font-family: var(--font-mono, "JetBrains Mono", monospace); + font-size: 0.85rem; + font-weight: 600; + color: var(--foreground, #0f172a); + display: flex; + align-items: baseline; + gap: 0.5rem; +} + +article .mkdocstrings dl dt .name, +article .doc.doc-object dl dt .name { + padding: 0.1rem 0.45rem; + border-radius: 0.45rem; + background: rgba(14, 165, 233, 0.15); + color: var(--primary-foreground, #0f172a); +} + +article .mkdocstrings dl dt .type, +article .doc.doc-object dl dt .type { + font-size: 0.76rem; + color: var(--muted-foreground, #475569); +} + +article .mkdocstrings dl dd, +article .doc.doc-object dl dd { + margin-left: 0; + font-size: 0.95rem; + color: var(--muted-foreground, #475569); + line-height: 1.7; +} + +article .mkdocstrings dl dd > :last-child { + margin-bottom: 0; +} + +/* Members tables */ +article .mkdocstrings table, +article .doc.doc-object table { + width: 100%; + border-collapse: collapse; + margin: 1rem 0 0; + font-size: 0.95rem; +} + +article .mkdocstrings table th, +article .doc.doc-object table th, +article .mkdocstrings table td, +article .doc.doc-object table td { + border: 1px solid rgba(148, 163, 184, 0.2); + padding: 0.65rem 0.85rem; + text-align: left; +} + +article .mkdocstrings table tr:nth-child(even), +article .doc.doc-object table tr:nth-child(even) { + background: rgba(148, 163, 184, 0.12); +} + +/* Collapsible members */ +article details.doc-section, +article .mkdocstrings details { + border: 1px solid rgba(148, 163, 184, 0.24); + border-radius: var(--radius-lg, 0.75rem); + background: rgba(15, 23, 42, 0.03); + padding: 1rem 1.35rem; +} + +article details.doc-section summary, +article .mkdocstrings details summary { + cursor: pointer; + font-weight: 600; + color: var(--foreground, #0f172a); +} + +article .mkdocstrings details[open] { + background: rgba(14, 165, 233, 0.08); +} + +/* Source buttons */ +article .mkdocstrings .doc .view-source, +article .doc.doc-object .view-source { + position: absolute; + top: 0.85rem; + right: 1.25rem; + display: inline-flex; + align-items: center; + gap: 0.35rem; + font-size: 0.8rem; + font-weight: 600; + color: var(--primary, #0284c7); + text-decoration: none; + padding: 0.45rem 0.65rem; + border-radius: 999px; + background: rgba(14, 165, 233, 0.15); + transition: transform 0.15s ease, box-shadow 0.15s ease; +} + +article .mkdocstrings .doc .view-source:hover, +article .doc.doc-object .view-source:hover { + transform: translateY(-1px); + box-shadow: 0 8px 20px rgba(14, 165, 233, 0.3); +} + +/* Dark theme adjustments */ +.dark article .mkdocstrings .doc, +.dark article .doc.doc-object, +.dark article .mkdocstrings .doc, +.dark article .doc.doc-object { + background: rgba(15, 23, 42, 0.82); + border-color: rgba(148, 163, 184, 0.18); + box-shadow: 0 28px 60px rgba(2, 6, 23, 0.75); +} + +.dark article .mkdocstrings .doc .doc-heading, +.dark article .mkdocstrings .doc .doc-heading { + background: linear-gradient(135deg, rgba(14, 165, 233, 0.22), transparent); + border-bottom-color: rgba(148, 163, 184, 0.25); +} + +.dark article .mkdocstrings dl, +.dark article .mkdocstrings dl { + background: rgba(148, 163, 184, 0.16); + border-color: rgba(148, 163, 184, 0.32); +} + +.dark article .mkdocstrings dl dt, +.dark article .mkdocstrings dl dt { + color: #e2e8f0; +} + +.dark article .mkdocstrings dl dt .type, +.dark article .mkdocstrings dl dt .type { + color: rgba(226, 232, 240, 0.74); +} + +.dark article .mkdocstrings dl dd, +.dark article .mkdocstrings dl dd { + color: rgba(226, 232, 240, 0.78); +} + +.dark article .mkdocstrings table tr:nth-child(even), +.dark article .mkdocstrings table tr:nth-child(even) { + background: rgba(148, 163, 184, 0.18); +} + +/* Responsive adjustments */ +@media (max-width: 768px) { + article .mkdocstrings .doc, + article .doc.doc-object { + margin: 1.75rem -0.75rem; + border-radius: var(--radius-lg, 0.75rem); + } + + article .mkdocstrings .doc .doc-heading, + article .doc.doc-object .doc-heading { + padding: 1.5rem 1.75rem; + flex-direction: column; + align-items: flex-start; + gap: 0.6rem; + } + + article .mkdocstrings .doc .doc-signature, + article .doc.doc-object .doc-signature, + article .mkdocstrings .doc pre.docstring-signature, + article .doc.doc-object pre.docstring-signature { + padding: 1rem 1.5rem; + } + + article .mkdocstrings .doc .doc-contents, + article .doc.doc-object .doc-contents { + padding: 1.5rem 1.5rem 1.9rem; + } +} + +/* ======================================== + Fallback Styles for Generic API Docs + (if mkdocstrings classes are unavailable) + ======================================== */ + +/* Generic API section styling - applies to any .api-doc container */ +article .api-doc, +article [class*="api-"], +article [class*="autodoc"] { + margin: 2rem 0; + border: 1px solid var(--border, rgba(148, 163, 184, 0.3)); + border-radius: var(--radius-lg, 0.75rem); + background: var(--card, #ffffff); + padding: 1.5rem; +} + +/* Generic function/class signature in monospace */ +article .signature, +article [class*="sig"], +article code.signature { + font-family: var(--font-mono, "JetBrains Mono", monospace); + font-size: 0.9rem; + display: block; + padding: 0.75rem 1rem; + background: rgba(15, 23, 42, 0.05); + border-radius: var(--radius-md, 0.5rem); + overflow-x: auto; +} + +.dark article .signature, +.dark article [class*="sig"], +.dark article .signature, +.dark article [class*="sig"] { + background: rgba(148, 163, 184, 0.12); +} diff --git a/docs/stylesheets/nav-scroll-fix.css b/docs/stylesheets/nav-scroll-fix.css new file mode 100644 index 00000000..c4fefc16 --- /dev/null +++ b/docs/stylesheets/nav-scroll-fix.css @@ -0,0 +1,482 @@ +/* + * Navigation & Scroll Enhancements for OpenJudge Documentation + * Phase 1: 导航滚动修复 + * + * Features: + * - 侧边栏滚动优化 + * - 当前页面高亮 + * - 滚动时侧边栏固定 + * - 目录 (TOC) 滚动跟随 + * - 平滑滚动 + */ + +/* ======================================== + Global Smooth Scroll + ======================================== */ + +html { + scroll-behavior: smooth; +} + +/* Respect reduced motion preference */ +@media (prefers-reduced-motion: reduce) { + html { + scroll-behavior: auto; + } +} + +/* ======================================== + Sidebar Navigation + ======================================== */ + +/* Sidebar container - sticky positioning */ +nav.sidebar, +.md-sidebar, +.nav-sidebar, +aside.sidebar { + position: sticky; + top: 0; + max-height: 100vh; + overflow-y: auto; + overflow-x: hidden; + /* Hide scrollbar by default, show on hover */ + scrollbar-width: none; + scrollbar-color: var(--muted-foreground, #d1d5db) transparent; +} + +/* Show scrollbar on hover (Firefox) */ +nav.sidebar:hover, +.md-sidebar:hover, +.nav-sidebar:hover, +aside.sidebar:hover { + scrollbar-width: thin; +} + +/* Custom scrollbar for sidebar (Webkit - hidden by default) */ +nav.sidebar::-webkit-scrollbar, +.md-sidebar::-webkit-scrollbar, +.nav-sidebar::-webkit-scrollbar, +aside.sidebar::-webkit-scrollbar { + width: 0; +} + +/* Show scrollbar on hover (Webkit) */ +nav.sidebar:hover::-webkit-scrollbar, +.md-sidebar:hover::-webkit-scrollbar, +.nav-sidebar:hover::-webkit-scrollbar, +aside.sidebar:hover::-webkit-scrollbar { + width: 4px; +} + +nav.sidebar::-webkit-scrollbar-track, +.md-sidebar::-webkit-scrollbar-track, +.nav-sidebar::-webkit-scrollbar-track, +aside.sidebar::-webkit-scrollbar-track { + background: transparent; +} + +nav.sidebar::-webkit-scrollbar-thumb, +.md-sidebar::-webkit-scrollbar-thumb, +.nav-sidebar::-webkit-scrollbar-thumb, +aside.sidebar::-webkit-scrollbar-thumb { + background: var(--muted-foreground, #d1d5db); + border-radius: 2px; +} + +nav.sidebar::-webkit-scrollbar-thumb:hover, +.md-sidebar::-webkit-scrollbar-thumb:hover, +.nav-sidebar::-webkit-scrollbar-thumb:hover, +aside.sidebar::-webkit-scrollbar-thumb:hover { + background: var(--foreground, #9ca3af); +} + +/* ======================================== + Navigation Links + ======================================== */ + +/* Base nav link styles */ +nav.sidebar a, +.md-sidebar a, +.nav-sidebar a, +aside.sidebar a, +.md-nav__link { + display: block; + padding: 0.5rem 0.75rem; + color: var(--muted-foreground, #6b7280); + text-decoration: none; + border-radius: 0.375rem; + transition: all 0.15s ease; + font-size: 0.875rem; + line-height: 1.5; +} + +/* Hover state */ +nav.sidebar a:hover, +.md-sidebar a:hover, +.nav-sidebar a:hover, +aside.sidebar a:hover, +.md-nav__link:hover { + color: var(--foreground, #1f2937); + background: var(--muted, rgba(0, 0, 0, 0.04)); +} + +/* ======================================== + Current Page Highlight + ======================================== */ + +/* Active/current page indicator */ +nav.sidebar a.active, +nav.sidebar a[aria-current="page"], +.md-sidebar a.active, +.md-sidebar a[aria-current="page"], +.nav-sidebar a.active, +.nav-sidebar a[aria-current="page"], +aside.sidebar a.active, +aside.sidebar a[aria-current="page"], +.md-nav__link--active, +.md-nav__item--active > .md-nav__link { + color: var(--primary, #3b82f6); + background: rgba(59, 130, 246, 0.1); + font-weight: 500; + position: relative; +} + +/* Active indicator bar */ +nav.sidebar a.active::before, +nav.sidebar a[aria-current="page"]::before, +.md-sidebar a.active::before, +.md-sidebar a[aria-current="page"]::before, +.nav-sidebar a.active::before, +aside.sidebar a.active::before, +.md-nav__link--active::before { + content: ''; + position: absolute; + left: 0; + top: 50%; + transform: translateY(-50%); + width: 3px; + height: 1.25rem; + background: var(--primary, #3b82f6); + border-radius: 0 2px 2px 0; +} + +/* ======================================== + Table of Contents (TOC) + ======================================== */ + +/* TOC container */ +.toc, +.md-sidebar--secondary, +.table-of-contents, +nav.toc { + position: sticky; + top: 1rem; + max-height: calc(100vh - 2rem); + overflow-y: auto; + padding-right: 0.5rem; +} + +/* TOC title */ +.toc-title, +.md-sidebar--secondary .md-nav__title, +.table-of-contents-title { + font-size: 0.75rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--muted-foreground, #9ca3af); + margin-bottom: 0.75rem; + padding: 0 0.5rem; +} + +/* TOC links */ +.toc a, +.md-sidebar--secondary a, +.table-of-contents a, +nav.toc a { + display: block; + padding: 0.375rem 0.5rem; + font-size: 0.8125rem; + color: var(--muted-foreground, #6b7280); + text-decoration: none; + border-left: 2px solid transparent; + transition: all 0.15s ease; + line-height: 1.4; +} + +.toc a:hover, +.md-sidebar--secondary a:hover, +.table-of-contents a:hover, +nav.toc a:hover { + color: var(--foreground, #1f2937); + border-left-color: var(--muted-foreground, #d1d5db); +} + +/* Active TOC item (scroll spy) */ +.toc a.active, +.toc a[aria-current="true"], +.md-sidebar--secondary a.active, +.table-of-contents a.active, +nav.toc a.active { + color: var(--primary, #3b82f6); + border-left-color: var(--primary, #3b82f6); + font-weight: 500; +} + +/* Nested TOC levels */ +.toc ul ul a, +.md-sidebar--secondary .md-nav--secondary a, +.table-of-contents ul ul a { + padding-left: 1rem; + font-size: 0.75rem; +} + +.toc ul ul ul a, +.table-of-contents ul ul ul a { + padding-left: 1.5rem; +} + +/* ======================================== + Scroll Progress Indicator + ======================================== */ + +.scroll-progress { + position: fixed; + top: 0; + left: 0; + width: 0%; + height: 2px; + background: var(--primary, #3b82f6); + z-index: 9999; + transition: width 0.1s ease-out; +} + +/* ======================================== + Scroll to Top Button + ======================================== */ + +.scroll-to-top { + position: fixed; + bottom: 2rem; + right: 2rem; + width: 2.5rem; + height: 2.5rem; + display: flex; + align-items: center; + justify-content: center; + background: var(--background, #fff); + border: 1px solid var(--border, #e5e7eb); + border-radius: 50%; + color: var(--muted-foreground, #6b7280); + cursor: pointer; + opacity: 0; + visibility: hidden; + transition: all 0.2s ease; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + z-index: 100; +} + +.scroll-to-top.visible { + opacity: 1; + visibility: visible; +} + +.scroll-to-top:hover { + color: var(--foreground, #1f2937); + border-color: var(--primary, #3b82f6); + transform: translateY(-2px); + box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15); +} + +/* ======================================== + Header Offset for Anchor Links + ======================================== */ + +/* Offset for fixed header when jumping to anchors */ +:target::before { + content: ''; + display: block; + height: 80px; + margin-top: -80px; + visibility: hidden; + pointer-events: none; +} + +/* Alternative using scroll-margin */ +h1[id], h2[id], h3[id], h4[id], h5[id], h6[id], +[id]:target { + scroll-margin-top: 80px; +} + +/* ======================================== + Dark Mode + ======================================== */ + +/* Dark mode scrollbar */ +.dark nav.sidebar::-webkit-scrollbar-thumb, +.dark .md-sidebar::-webkit-scrollbar-thumb, +.dark .nav-sidebar::-webkit-scrollbar-thumb, +.dark aside.sidebar::-webkit-scrollbar-thumb, +.dark nav.sidebar::-webkit-scrollbar-thumb, +.dark .md-sidebar::-webkit-scrollbar-thumb { + background: var(--muted-foreground, #4b5563); +} + +/* Dark mode nav links */ +.dark nav.sidebar a, +.dark .md-sidebar a, +.dark .nav-sidebar a, +.dark aside.sidebar a, +.dark nav.sidebar a, +.dark .md-sidebar a { + color: var(--muted-foreground, #9ca3af); +} + +.dark nav.sidebar a:hover, +.dark .md-sidebar a:hover, +.dark .nav-sidebar a:hover, +.dark aside.sidebar a:hover, +.dark nav.sidebar a:hover, +.dark .md-sidebar a:hover { + color: var(--foreground, #f3f4f6); + background: rgba(255, 255, 255, 0.05); +} + +/* Dark mode active state */ +.dark nav.sidebar a.active, +.dark nav.sidebar a[aria-current="page"], +.dark .md-sidebar a.active, +.dark .md-nav__link--active, +.dark nav.sidebar a.active, +.dark .md-sidebar a.active { + color: var(--primary, #60a5fa); + background: rgba(96, 165, 250, 0.1); +} + +/* Dark mode TOC */ +.dark .toc a, +.dark .md-sidebar--secondary a, +.dark .table-of-contents a, +.dark .toc a, +.dark .md-sidebar--secondary a { + color: var(--muted-foreground, #9ca3af); +} + +.dark .toc a:hover, +.dark .md-sidebar--secondary a:hover, +.dark .toc a:hover, +.dark .md-sidebar--secondary a:hover { + color: var(--foreground, #f3f4f6); + border-left-color: var(--muted-foreground, #6b7280); +} + +.dark .toc a.active, +.dark .md-sidebar--secondary a.active, +.dark .toc a.active, +.dark .md-sidebar--secondary a.active { + color: var(--primary, #60a5fa); + border-left-color: var(--primary, #60a5fa); +} + +/* Dark mode scroll to top */ +.dark .scroll-to-top, +.dark .scroll-to-top { + background: var(--background, #1f2937); + border-color: var(--border, #374151); + color: var(--muted-foreground, #9ca3af); +} + +.dark .scroll-to-top:hover, +.dark .scroll-to-top:hover { + color: var(--foreground, #f3f4f6); + border-color: var(--primary, #60a5fa); +} + +/* ======================================== + Mobile Navigation + ======================================== */ + +@media (max-width: 768px) { + /* Mobile sidebar - only apply custom positioning if sidebar has .mobile-drawer class */ + nav.sidebar.mobile-drawer, + .md-sidebar.mobile-drawer, + .nav-sidebar.mobile-drawer, + aside.sidebar.mobile-drawer { + position: fixed; + top: 0; + left: 0; + width: 280px; + height: 100vh; + max-height: 100vh; + transform: translateX(-100%); + transition: transform 0.3s ease; + z-index: 1000; + background: var(--background, #fff); + border-right: 1px solid var(--border, #e5e7eb); + padding: 1rem; + } + + /* Open state for drawer navigation */ + nav.sidebar.mobile-drawer.open, + .md-sidebar.mobile-drawer.open, + .nav-sidebar.mobile-drawer.open, + aside.sidebar.mobile-drawer.open { + transform: translateX(0); + } + + /* Mobile TOC - hidden by default */ + .toc, + .md-sidebar--secondary, + .table-of-contents { + display: none; + } + + /* Scroll to top - smaller on mobile */ + .scroll-to-top { + bottom: 1rem; + right: 1rem; + width: 2.25rem; + height: 2.25rem; + } + + /* Reduce scroll margin for smaller header */ + h1[id], h2[id], h3[id], h4[id], h5[id], h6[id], + [id]:target { + scroll-margin-top: 60px; + } +} + +/* Dark mode mobile sidebar */ +@media (max-width: 768px) { + .dark nav.sidebar.mobile-drawer, + .dark .md-sidebar.mobile-drawer, + .dark .nav-sidebar.mobile-drawer, + .dark aside.sidebar.mobile-drawer, + .dark nav.sidebar.mobile-drawer, + .dark .md-sidebar.mobile-drawer { + background: var(--background, #111827); + border-right-color: var(--border, #374151); + } +} + +/* ======================================== + Overlay for Mobile Menu + ======================================== */ + +.nav-overlay { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.5); + opacity: 0; + visibility: hidden; + transition: all 0.3s ease; + z-index: 999; +} + +.nav-overlay.visible { + opacity: 1; + visibility: visible; +} diff --git a/docs/stylesheets/readability-enhancements.css b/docs/stylesheets/readability-enhancements.css new file mode 100644 index 00000000..19865c41 --- /dev/null +++ b/docs/stylesheets/readability-enhancements.css @@ -0,0 +1,253 @@ +/* Readability enhancements */ + +/* ======================================== + Collapsible Sections (Details/Summary) + ======================================== */ + +/* Details container base styles */ +article details, +.prose details, +.md-typeset details { + margin: 1.5rem 0; + padding: 0; + border: 1px solid var(--border, #e5e7eb); + border-radius: 0.5rem; + background: var(--card, #ffffff); + overflow: hidden; + box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.05); +} + +/* Dark mode support */ +.dark article details, +.dark .prose details, +.dark .md-typeset details { + background: var(--card, #1f2937); + border-color: var(--border, #374151); +} + +/* Summary (clickable header) */ +article details summary, +.prose details summary, +.md-typeset details summary { + display: flex; + align-items: center; + gap: 0.75rem; + padding: 0.875rem 1rem; + font-weight: 600; + font-size: 0.95rem; + cursor: pointer; + user-select: none; + list-style: none; + background: var(--muted, #f9fafb); + border-bottom: 1px solid transparent; + transition: all 0.2s ease; +} + +/* Remove default marker */ +article details summary::-webkit-details-marker, +.prose details summary::-webkit-details-marker, +.md-typeset details summary::-webkit-details-marker { + display: none; +} + +article details summary::marker, +.prose details summary::marker, +.md-typeset details summary::marker { + display: none; + content: ""; +} + +/* Chevron icon (using CSS pseudo-element) */ +article details summary::before, +.prose details summary::before, +.md-typeset details summary::before { + content: ""; + display: inline-flex; + align-items: center; + justify-content: center; + width: 1.25rem; + height: 1.25rem; + flex-shrink: 0; + background-color: var(--foreground, #111827); + -webkit-mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='20' height='20' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2.5' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpolyline points='9 18 15 12 9 6'%3E%3C/polyline%3E%3C/svg%3E"); + mask-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='20' height='20' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2.5' stroke-linecap='round' stroke-linejoin='round'%3E%3Cpolyline points='9 18 15 12 9 6'%3E%3C/polyline%3E%3C/svg%3E"); + -webkit-mask-repeat: no-repeat; + mask-repeat: no-repeat; + -webkit-mask-position: center; + mask-position: center; + -webkit-mask-size: contain; + mask-size: contain; + transition: transform 0.25s cubic-bezier(0.4, 0, 0.2, 1); +} + +/* Rotate chevron when open */ +article details[open] summary::before, +.prose details[open] summary::before, +.md-typeset details[open] summary::before { + transform: rotate(90deg); +} + +/* Summary hover state */ +article details summary:hover, +.prose details summary:hover, +.md-typeset details summary:hover { + background: var(--accent, #f3f4f6); +} + +/* Dark mode summary */ +.dark article details summary, +.dark .prose details summary, +.dark .md-typeset details summary { + background: var(--muted, #1f2937); +} + +.dark article details summary::before, +.dark .prose details summary::before, +.dark .md-typeset details summary::before { + background-color: var(--foreground, #f9fafb); +} + +.dark article details summary:hover, +.dark .prose details summary:hover, +.dark .md-typeset details summary:hover { + background: var(--accent, #374151); +} + +/* Open state: show border under summary */ +article details[open] summary, +.prose details[open] summary, +.md-typeset details[open] summary { + border-bottom-color: var(--border, #e5e7eb); +} + +.dark article details[open] summary, +.dark .prose details[open] summary, +.dark .md-typeset details[open] summary { + border-bottom-color: var(--border, #374151); +} + +/* Details content area */ +article details > *:not(summary), +.prose details > *:not(summary), +.md-typeset details > *:not(summary) { + padding: 0 1rem; +} + +article details > *:not(summary):first-of-type, +.prose details > *:not(summary):first-of-type, +.md-typeset details > *:not(summary):first-of-type { + padding-top: 1rem; +} + +article details > *:not(summary):last-child, +.prose details > *:not(summary):last-child, +.md-typeset details > *:not(summary):last-child { + padding-bottom: 1rem; +} + +/* ======================================== + Admonition-style Details (note, tip, warning, etc.) + For pymdownx.details integration + ======================================== */ + +/* Note style details */ +article details.note summary, +.prose details.note summary, +.md-typeset details.note summary { + background: rgba(59, 130, 246, 0.08); +} + +article details.note summary::before, +.prose details.note summary::before, +.md-typeset details.note summary::before { + background-color: #3b82f6; +} + +/* Tip style details */ +article details.tip summary, +.prose details.tip summary, +.md-typeset details.tip summary { + background: rgba(16, 185, 129, 0.08); +} + +article details.tip summary::before, +.prose details.tip summary::before, +.md-typeset details.tip summary::before { + background-color: #10b981; +} + +/* Warning style details */ +article details.warning summary, +.prose details.warning summary, +.md-typeset details.warning summary { + background: rgba(245, 158, 11, 0.08); +} + +article details.warning summary::before, +.prose details.warning summary::before, +.md-typeset details.warning summary::before { + background-color: #f59e0b; +} + +/* Danger style details */ +article details.danger summary, +.prose details.danger summary, +.md-typeset details.danger summary { + background: rgba(239, 68, 68, 0.08); +} + +article details.danger summary::before, +.prose details.danger summary::before, +.md-typeset details.danger summary::before { + background-color: #ef4444; +} + +/* Info style details */ +article details.info summary, +.prose details.info summary, +.md-typeset details.info summary { + background: rgba(6, 182, 212, 0.08); +} + +article details.info summary::before, +.prose details.info summary::before, +.md-typeset details.info summary::before { + background-color: #06b6d4; +} + +/* Example style details */ +article details.example summary, +.prose details.example summary, +.md-typeset details.example summary { + background: rgba(139, 92, 246, 0.08); +} + +article details.example summary::before, +.prose details.example summary::before, +.md-typeset details.example summary::before { + background-color: #8b5cf6; +} + +/* ======================================== + Focus styles for accessibility + ======================================== */ + +article details summary:focus, +.prose details summary:focus, +.md-typeset details summary:focus { + outline: 2px solid var(--ring, #3b82f6); + outline-offset: 2px; +} + +article details summary:focus:not(:focus-visible), +.prose details summary:focus:not(:focus-visible), +.md-typeset details summary:focus:not(:focus-visible) { + outline: none; +} + +article details summary:focus-visible, +.prose details summary:focus-visible, +.md-typeset details summary:focus-visible { + outline: 2px solid var(--ring, #3b82f6); + outline-offset: 2px; +} diff --git a/docs/stylesheets/responsive.css b/docs/stylesheets/responsive.css new file mode 100644 index 00000000..251965d7 --- /dev/null +++ b/docs/stylesheets/responsive.css @@ -0,0 +1,1083 @@ + +article p { + font-size: 19px; +} + +article ul>li { + font-size: 19px; + margin-top: calc(var(--spacing) * 2); +} + +.card-desc { + margin: 0; + font-size: 15px; + opacity: 0.8; + line-height: 1.6; + pointer-events: none; +} +.card-header { + display: inline-flex !important; + align-items: center !important; + flex-wrap: nowrap !important; + font-size: 19px; + margin-bottom: 12px; + white-space: nowrap; + pointer-events: none; +} + + +/* 修改链接样式 */ +a[data-sidebar="menu-button"] { + font-size: 17px; + font-weight: 300; + margin-bottom: 12px; +} + +.pr-2 { + font-size: 20px; +} + +.text-muted-foreground { + font-size: 15px; + color: var(--color-muted-foreground); +} + +article .typography h2 { + font-size: 50px; + font-weight: 500; + margin-top: calc(var(--spacing) * 20); +} + + +article .typography h2 { + font-size: 30px; + font-weight: 1000; + margin-top: calc(var(--spacing) * 20); +} + +/* 如果你可以添加自定义类 */ +[data-slot="sidebar-menu-button"] { + white-space: nowrap !important; +} + +/* 方法3:如果上面不生效,使用更具体的选择器 */ +div[data-slot="sidebar-content"].no-scrollbar { + scrollbar-width: thin !important; +} + +div[data-slot="sidebar-content"].no-scrollbar::-webkit-scrollbar { + display: block !important; + width: 1px !important; +} + +.sidebar .text-sidebar-foreground, +div.text-sidebar-foreground.sticky { + height: calc(100svh - var(--header-height) - 20px) !important; +} + +/* + * Responsive Enhancements for OpenJudge Documentation + * Phase 5: 响应式完善 + * + * Features: + * - 统一断点系统 + * - 移动端导航优化 + * - 触摸友好交互 + * - 多设备布局适配 + * - 响应式工具类 + * - 打印样式优化 + */ + +/* ======================================== + Breakpoint System (CSS Custom Properties) + ======================================== */ + +:root { + /* Breakpoint values (for reference in media queries) */ + /* --breakpoint-xs: 0px; Mobile portrait */ + /* --breakpoint-sm: 640px; Mobile landscape */ + /* --breakpoint-md: 768px; Tablet portrait */ + /* --breakpoint-lg: 1024px; Tablet landscape / Small desktop */ + /* --breakpoint-xl: 1280px; Desktop */ + /* --breakpoint-2xl: 1536px; Large desktop */ + + /* Container max-widths */ + --container-sm: 640px; + --container-md: 768px; + --container-lg: 1024px; + --container-xl: 1280px; + --container-2xl: 1536px; + + /* Responsive spacing */ + --spacing-mobile: 1rem; + --spacing-tablet: 1.5rem; + --spacing-desktop: 2rem; + + /* Touch target minimum size */ + --touch-target-min: 44px; +} + +/* ======================================== + Base Responsive Container + ======================================== */ + +.container, +article, +.md-content__inner, +.prose { + width: 100%; + margin-left: auto; + margin-right: auto; + padding-left: var(--spacing-mobile); + padding-right: var(--spacing-mobile); +} + +@media (min-width: 640px) { + .container, + article, + .md-content__inner, + .prose { + padding-left: var(--spacing-tablet); + padding-right: var(--spacing-tablet); + } +} + +@media (min-width: 1024px) { + .container, + article, + .md-content__inner, + .prose { + padding-left: var(--spacing-desktop); + padding-right: var(--spacing-desktop); + max-width: var(--container-lg); + } +} + +@media (min-width: 1280px) { + .container, + article, + .md-content__inner, + .prose { + max-width: var(--container-xl); + } +} + +/* ======================================== + Mobile First Base Styles (< 640px) + ======================================== */ + +/* Typography scaling */ +html { + font-size: 15px; +} + +@media (min-width: 640px) { + html { + font-size: 16px; + } +} + +/* Mobile layout adjustments */ +@media (max-width: 639px) { + /* Main content full width */ + .md-main__inner, + main.md-main { + padding: 0; + } + + /* Reduce margins on mobile */ + article > *, + .prose > *, + .md-typeset > * { + margin-left: 0; + margin-right: 0; + } + + /* Stack grids on mobile */ + .grid, + .md-grid { + display: block; + } + + .grid > *, + .md-grid > * { + width: 100%; + margin-bottom: 1rem; + } + + /* Hide TOC on mobile */ + .md-sidebar--secondary, + .toc, + .table-of-contents { + display: none; + } + + /* Full width code blocks */ + pre, + .highlight, + .codehilite { + margin-left: calc(-1 * var(--spacing-mobile)); + margin-right: calc(-1 * var(--spacing-mobile)); + border-radius: 0; + border-left: none; + border-right: none; + } + + pre code { + padding-left: var(--spacing-mobile); + padding-right: var(--spacing-mobile); + } + + /* Full width tables */ + .table-responsive, + table { + margin-left: calc(-1 * var(--spacing-mobile)); + margin-right: calc(-1 * var(--spacing-mobile)); + width: calc(100% + 2 * var(--spacing-mobile)); + } + + /* Reduce heading sizes on mobile */ + h1 { font-size: 1.75rem; } + h2 { font-size: 1.375rem; } + h3 { font-size: 1.125rem; } + h4 { font-size: 1rem; } +} + +/* ======================================== + Tablet Styles (640px - 1023px) + ======================================== */ + +@media (min-width: 640px) and (max-width: 1023px) { + /* Two column layout for larger tablets */ + .md-main__inner { + display: flex; + flex-wrap: wrap; + } + + /* Sidebar takes full width on tablet portrait */ + .md-sidebar--primary { + width: 100%; + max-width: none; + position: relative; + height: auto; + max-height: none; + } + + /* Content takes full width */ + .md-content { + width: 100%; + max-width: none; + } + + /* Hide secondary sidebar on tablet */ + .md-sidebar--secondary { + display: none; + } + + /* Grid adjustments */ + .grid-cols-3, + .md-grid-3 { + grid-template-columns: repeat(2, 1fr); + } + + .grid-cols-4, + .md-grid-4 { + grid-template-columns: repeat(2, 1fr); + } +} + +/* ======================================== + Desktop Styles (1024px+) + ======================================== */ + +@media (min-width: 1024px) { + /* Three column layout */ + .md-main__inner { + display: flex; + } + + /* Primary sidebar */ + .md-sidebar--primary { + width: 240px; + flex-shrink: 0; + } + + /* Main content */ + .md-content { + flex: 1; + min-width: 0; + } + + /* Secondary sidebar (TOC) */ + .md-sidebar--secondary { + width: 200px; + flex-shrink: 0; + display: block; + } +} + +@media (min-width: 1280px) { + .md-sidebar--primary { + width: 280px; + } + + .md-sidebar--secondary { + width: 240px; + } +} + +/* ======================================== + Touch Friendly Interactions + ======================================== */ + +/* Ensure minimum touch target size */ +@media (hover: none) and (pointer: coarse) { + /* Touch devices */ + + a, + button, + .btn, + .button, + input[type="button"], + input[type="submit"], + .md-nav__link, + .tabbed-labels > label, + details summary { + min-height: var(--touch-target-min); + min-width: var(--touch-target-min); + padding: 0.75rem 1rem; + } + + /* Increase tap targets in navigation */ + .md-nav__link, + nav.sidebar a, + .nav-sidebar a { + padding: 0.875rem 1rem; + } + + /* Larger checkboxes */ + input[type="checkbox"], + input[type="radio"] { + width: 1.25rem; + height: 1.25rem; + } + + /* Remove hover effects on touch - use active instead */ + .highlight:hover, + .codehilite:hover, + pre:hover { + transform: none; + box-shadow: var(--rm-shadow-sm); + } + + .highlight:active, + .codehilite:active, + pre:active { + transform: scale(0.99); + } + + /* Show copy button always on touch devices */ + .copy-button, + .md-clipboard, + button[data-clipboard-target] { + opacity: 1 !important; + } + + /* Disable hover lift effects */ + button:hover, + .button:hover, + .btn:hover { + transform: none; + } + + button:active, + .button:active, + .btn:active { + transform: scale(0.98); + } +} + +/* Hover-capable devices */ +@media (hover: hover) and (pointer: fine) { + /* Enable hover effects */ + .highlight:hover, + .admonition:hover, + .workflow ol > li:hover { + transform: translateY(-2px); + } +} + +/* ======================================== + Mobile Navigation Enhancements + ======================================== */ + +/* Mobile menu toggle button */ +.mobile-menu-toggle { + display: none; + position: fixed; + bottom: 1.5rem; + left: 1.5rem; + width: 3rem; + height: 3rem; + border-radius: 50%; + background: var(--primary, #3b82f6); + color: white; + border: none; + box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4); + cursor: pointer; + z-index: 1001; + transition: all 0.2s ease; +} + +.mobile-menu-toggle:active { + transform: scale(0.95); +} + +@media (max-width: 767px) { + .mobile-menu-toggle { + display: flex; + align-items: center; + justify-content: center; + } +} + +/* Mobile menu icon */ +.mobile-menu-toggle .icon-menu { + width: 1.25rem; + height: 1.25rem; +} + +.mobile-menu-toggle .icon-close { + display: none; + width: 1.25rem; + height: 1.25rem; +} + +.mobile-menu-toggle.active .icon-menu { + display: none; +} + +.mobile-menu-toggle.active .icon-close { + display: block; +} + +/* Mobile sidebar drawer */ +@media (max-width: 767px) { + .md-sidebar--primary, + nav.sidebar, + .nav-sidebar { + position: fixed; + top: 0; + left: 0; + width: 85%; + max-width: 320px; + height: 100vh; + max-height: 100vh; + background: var(--background, #fff); + border-right: 1px solid var(--border, #e5e7eb); + transform: translateX(-100%); + transition: transform 0.3s cubic-bezier(0.4, 0, 0.2, 1); + z-index: 1000; + overflow-y: auto; + -webkit-overflow-scrolling: touch; + padding: 1rem; + padding-top: 2rem; + } + + .md-sidebar--primary.open, + nav.sidebar.open, + .nav-sidebar.open { + transform: translateX(0); + } + + /* Dark mode mobile sidebar */ + .dark .md-sidebar--primary, + .dark nav.sidebar, + .dark .md-sidebar--primary { + background: var(--background, #111827); + border-right-color: var(--border, #374151); + } +} + +/* Mobile navigation overlay */ +.mobile-nav-overlay { + display: none; + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.5); + backdrop-filter: blur(4px); + -webkit-backdrop-filter: blur(4px); + z-index: 999; + opacity: 0; + transition: opacity 0.3s ease; +} + +@media (max-width: 767px) { + .mobile-nav-overlay { + display: block; + pointer-events: none; + } + + .mobile-nav-overlay.visible { + opacity: 1; + pointer-events: auto; + } +} + +/* ======================================== + Responsive Images + ======================================== */ + +article img, +.prose img, +.md-typeset img { + max-width: 100%; + height: auto; +} + +/* Full bleed images on mobile */ +@media (max-width: 639px) { + article img.full-bleed, + .prose img.full-bleed, + figure.full-bleed img { + margin-left: calc(-1 * var(--spacing-mobile)); + margin-right: calc(-1 * var(--spacing-mobile)); + max-width: calc(100% + 2 * var(--spacing-mobile)); + width: calc(100% + 2 * var(--spacing-mobile)); + border-radius: 0; + } +} + +/* Image grids */ +.image-grid { + display: grid; + grid-template-columns: repeat(2, 1fr); + gap: 1rem; +} + +@media (min-width: 640px) { + .image-grid { + grid-template-columns: repeat(3, 1fr); + } +} + +@media (min-width: 1024px) { + .image-grid { + grid-template-columns: repeat(4, 1fr); + } +} + +/* ======================================== + Responsive Tables + ======================================== */ + +/* Table wrapper for horizontal scroll */ +.table-responsive { + width: 100%; + overflow-x: auto; + -webkit-overflow-scrolling: touch; + margin: 1.5em 0; +} + +/* Scroll shadow indicators */ +.table-responsive { + position: relative; +} + +.table-responsive::before, +.table-responsive::after { + content: ''; + position: absolute; + top: 0; + bottom: 0; + width: 30px; + pointer-events: none; + opacity: 0; + transition: opacity 0.2s ease; + z-index: 1; +} + +.table-responsive::before { + left: 0; + background: linear-gradient(to right, var(--background, #fff) 0%, transparent 100%); +} + +.table-responsive::after { + right: 0; + background: linear-gradient(to left, var(--background, #fff) 0%, transparent 100%); +} + +/* Show shadows when scrollable */ +.table-responsive.can-scroll-left::before { + opacity: 1; +} + +.table-responsive.can-scroll-right::after { + opacity: 1; +} + +/* Card-style tables on mobile */ +@media (max-width: 639px) { + table.responsive-cards, + .table-cards table { + display: block; + } + + table.responsive-cards thead, + .table-cards thead { + display: none; + } + + table.responsive-cards tbody, + table.responsive-cards tr, + .table-cards tbody, + .table-cards tr { + display: block; + } + + table.responsive-cards tr, + .table-cards tr { + margin-bottom: 1rem; + padding: 1rem; + border: 1px solid var(--border, #e5e7eb); + border-radius: var(--radius-lg, 0.5rem); + background: var(--card, #fff); + } + + table.responsive-cards td, + .table-cards td { + display: flex; + justify-content: space-between; + padding: 0.5rem 0; + border: none; + border-bottom: 1px solid var(--border, #e5e7eb); + } + + table.responsive-cards td:last-child, + .table-cards td:last-child { + border-bottom: none; + } + + table.responsive-cards td::before, + .table-cards td::before { + content: attr(data-label); + font-weight: 600; + color: var(--muted-foreground, #6b7280); + margin-right: 1rem; + } +} + +/* ======================================== + Responsive Code Blocks + ======================================== */ + +@media (max-width: 639px) { + /* Smaller font on mobile */ + pre code, + .highlight code, + code { + font-size: 0.75rem; + line-height: 1.6; + } + + /* Reduce padding */ + pre code, + .highlight pre code { + padding: 0.75rem 1rem; + } + + /* Hide line numbers on very small screens */ + .highlight .linenos, + .highlight .linenodiv { + display: none; + } + + /* Simpler copy button */ + .copy-button, + .md-clipboard { + padding: 0.25rem 0.5rem; + font-size: 0.6875rem; + } +} + +/* ======================================== + Responsive Typography + ======================================== */ + +/* Fluid typography scale */ +@media (max-width: 639px) { + article, + .prose, + .md-typeset { + font-size: 0.9375rem; + line-height: 1.65; + } + + article h1, + .prose h1, + .md-typeset h1 { + font-size: 1.625rem; + line-height: 1.2; + } + + article h2, + .prose h2, + .md-typeset h2 { + font-size: 1.25rem; + margin-top: 2em; + } + + article h3, + .prose h3, + .md-typeset h3 { + font-size: 1.0625rem; + } + + /* Tighter spacing */ + article p, + .prose p { + margin-bottom: 1em; + } + + article ul, + article ol, + .prose ul, + .prose ol { + padding-left: 1.25rem; + } +} + +/* ======================================== + Responsive Admonitions & Cards + ======================================== */ + +@media (max-width: 639px) { + article .admonition, + .prose .admonition, + .md-typeset .admonition, + article details, + .prose details { + margin-left: calc(-1 * var(--spacing-mobile)); + margin-right: calc(-1 * var(--spacing-mobile)); + border-radius: 0; + border-left: none; + border-right: none; + border-top: 3px solid; + padding: 0.875rem 1rem; + } + + article .admonition-title, + .prose .admonition-title, + article details summary { + margin: -0.875rem -1rem 0.75rem; + padding: 0.625rem 1rem; + font-size: 0.875rem; + } + + /* Workflow adjustments */ + .workflow ol > li, + ol.workflow-steps > li { + padding-left: 2.5rem; + } + + .workflow ol > li::before, + ol.workflow-steps > li::before { + width: 1.5rem; + height: 1.5rem; + font-size: 0.75rem; + } + + .workflow ol > li::after, + ol.workflow-steps > li::after { + left: calc(0.75rem - 0.5px); + top: 1.5rem; + height: calc(100% - 1.5rem); + } +} + +/* ======================================== + Responsive Tabs + ======================================== */ + +@media (max-width: 639px) { + /* Scrollable tabs on mobile */ + .tabbed-labels, + .tabbed-set > .tabbed-labels { + overflow-x: auto; + -webkit-overflow-scrolling: touch; + scrollbar-width: none; + -ms-overflow-style: none; + } + + .tabbed-labels::-webkit-scrollbar { + display: none; + } + + .tabbed-labels > label, + .tabbed-set label { + flex-shrink: 0; + padding: 0.5rem 0.875rem; + font-size: 0.8125rem; + } +} + +/* ======================================== + Responsive Utility Classes + ======================================== */ + +/* Hide on specific breakpoints */ +@media (max-width: 639px) { + .hide-mobile, + .hidden-mobile, + .sm\:hidden { + display: none !important; + } +} + +@media (min-width: 640px) and (max-width: 767px) { + .hide-tablet-portrait, + .md\:hidden { + display: none !important; + } +} + +@media (min-width: 768px) and (max-width: 1023px) { + .hide-tablet, + .lg\:hidden { + display: none !important; + } +} + +@media (min-width: 1024px) { + .hide-desktop, + .hidden-desktop, + .xl\:hidden { + display: none !important; + } +} + +/* Show on specific breakpoints */ +@media (max-width: 639px) { + .show-mobile, + .visible-mobile { + display: block !important; + } +} + +@media (min-width: 640px) { + .show-mobile, + .visible-mobile { + display: none !important; + } +} + +@media (min-width: 1024px) { + .show-desktop, + .visible-desktop { + display: block !important; + } +} + +@media (max-width: 1023px) { + .show-desktop, + .visible-desktop { + display: none !important; + } +} + +/* Text alignment utilities */ +@media (max-width: 639px) { + .text-center-mobile { + text-align: center; + } + + .text-left-mobile { + text-align: left; + } +} + +/* Spacing utilities */ +@media (max-width: 639px) { + .p-mobile-0 { padding: 0 !important; } + .p-mobile-1 { padding: 0.25rem !important; } + .p-mobile-2 { padding: 0.5rem !important; } + .p-mobile-4 { padding: 1rem !important; } + + .m-mobile-0 { margin: 0 !important; } + .m-mobile-auto { margin: auto !important; } +} + +/* ======================================== + Print Styles + ======================================== */ + +@media print { + /* Hide non-essential elements */ + nav, + .md-sidebar, + .sidebar, + .nav-sidebar, + .toc, + .table-of-contents, + .scroll-to-top, + .mobile-menu-toggle, + .mobile-nav-overlay, + .copy-button, + .md-clipboard, + footer, + .md-footer { + display: none !important; + } + + /* Full width content */ + .md-content, + article, + .prose, + main { + width: 100% !important; + max-width: none !important; + margin: 0 !important; + padding: 0 !important; + } + + /* Print-friendly colors */ + body, + article, + .prose, + .md-typeset { + color: #000 !important; + background: #fff !important; + } + + /* Links show URL */ + a[href]::after { + content: " (" attr(href) ")"; + font-size: 0.8em; + color: #666; + } + + /* Don't show URL for internal links */ + a[href^="#"]::after, + a[href^="/"]::after { + content: ""; + } + + /* Avoid page breaks in bad places */ + h1, h2, h3, h4, h5, h6 { + page-break-after: avoid; + } + + pre, blockquote, table, figure, .admonition { + page-break-inside: avoid; + } + + /* Code blocks */ + pre, .highlight { + border: 1px solid #ccc; + background: #f5f5f5 !important; + overflow-x: visible; + white-space: pre-wrap; + word-wrap: break-word; + } + + /* Tables */ + table { + border-collapse: collapse; + } + + th, td { + border: 1px solid #ccc; + padding: 0.5rem; + } +} + +/* ======================================== + Landscape Orientation Fixes + ======================================== */ + +@media (max-height: 500px) and (orientation: landscape) { + /* Reduce header size on short landscape screens */ + .md-header, + header { + padding: 0.5rem 1rem; + } + + /* Reduce scroll margin for shorter header */ + h1[id], h2[id], h3[id], h4[id], h5[id], h6[id] { + scroll-margin-top: 50px; + } + + /* Compact navigation */ + .md-nav__link, + nav a { + padding: 0.375rem 0.75rem; + } +} + +/* ======================================== + High DPI / Retina Display + ======================================== */ + +@media (-webkit-min-device-pixel-ratio: 2), (min-resolution: 192dpi) { + /* Thinner borders on retina */ + .highlight, + pre, + table, + .admonition, + details { + border-width: 0.5px; + } +} + +/* ======================================== + Dark Mode Responsive Adjustments + ======================================== */ + +@media (max-width: 639px) { + .dark .mobile-nav-overlay, + .dark .mobile-nav-overlay { + background: rgba(0, 0, 0, 0.7); + } + + /* Dark mode scroll shadows */ + .dark .table-responsive::before, + .dark .table-responsive::before { + background: linear-gradient(to right, var(--background, #111827) 0%, transparent 100%); + } + + .dark .table-responsive::after, + .dark .table-responsive::after { + background: linear-gradient(to left, var(--background, #111827) 0%, transparent 100%); + } +} + +/* ======================================== + Safe Area Insets (Notch devices) + ======================================== */ + +@supports (padding: max(0px)) { + /* Account for notch on modern phones */ + .md-header, + header { + padding-left: max(1rem, env(safe-area-inset-left)); + padding-right: max(1rem, env(safe-area-inset-right)); + } + + .md-sidebar--primary, + nav.sidebar { + padding-left: max(1rem, env(safe-area-inset-left)); + } + + .mobile-menu-toggle { + bottom: max(1.5rem, calc(env(safe-area-inset-bottom) + 0.5rem)); + left: max(1.5rem, calc(env(safe-area-inset-left) + 0.5rem)); + } + + .scroll-to-top { + bottom: max(2rem, calc(env(safe-area-inset-bottom) + 0.5rem)); + right: max(2rem, calc(env(safe-area-inset-right) + 0.5rem)); + } +} diff --git a/docs/stylesheets/syntax-highlight.css b/docs/stylesheets/syntax-highlight.css new file mode 100644 index 00000000..7cfcf6ba --- /dev/null +++ b/docs/stylesheets/syntax-highlight.css @@ -0,0 +1,305 @@ +/* + * Syntax Highlighting Theme for OpenJudge + * Based on modern code editor color schemes + */ + +/* ======================================== + Light Mode Syntax Highlighting + ======================================== */ + +/* Keywords: from, import, def, return, lambda, class, etc. */ +.highlight .k, /* Keyword */ +.highlight .kn, /* Keyword.Namespace (import, from) */ +.highlight .kd, /* Keyword.Declaration (def, class) */ +.highlight .kr, /* Keyword.Reserved (return) */ +.highlight .kc, /* Keyword.Constant (True, False, None) */ +.codehilite .k, +.codehilite .kn, +.codehilite .kd, +.codehilite .kr, +.codehilite .kc { + color: #cf222e; + font-weight: 500; +} + +/* Strings: "...", '...' */ +.highlight .s, /* String */ +.highlight .s1, /* String.Single */ +.highlight .s2, /* String.Double */ +.highlight .se, /* String.Escape */ +.codehilite .s, +.codehilite .s1, +.codehilite .s2, +.codehilite .se { + color: #0a3069; +} + +/* Comments */ +.highlight .c, /* Comment */ +.highlight .c1, /* Comment.Single */ +.highlight .cm, /* Comment.Multiline */ +.highlight .cp, /* Comment.Preproc */ +.codehilite .c, +.codehilite .c1, +.codehilite .cm, +.codehilite .cp { + color: #6e7781; + font-style: italic; +} + +/* Function and Class Names */ +.highlight .nf, /* Name.Function */ +.highlight .nc, /* Name.Class */ +.codehilite .nf, +.codehilite .nc { + color: #0550ae; +} + +/* Builtin Functions: dict, bool, etc. */ +.highlight .nb, /* Name.Builtin */ +.highlight .bp, /* Name.Builtin.Pseudo */ +.codehilite .nb, +.codehilite .bp { + color: #953800; +} + +/* Numbers */ +.highlight .m, /* Number */ +.highlight .mi, /* Number.Integer */ +.highlight .mf, /* Number.Float */ +.codehilite .m, +.codehilite .mi, +.codehilite .mf { + color: #0550ae; +} + +/* Operators: =, ==, ->, etc. */ +.highlight .o, /* Operator */ +.highlight .ow, /* Operator.Word (and, or, in) */ +.codehilite .o, +.codehilite .ow { + color: #cf222e; + font-weight: 500; +} + +/* Punctuation: (), [], {}, :, , */ +.highlight .p, /* Punctuation */ +.codehilite .p { + color: #24292f; +} + +/* Variables and Parameters */ +.highlight .n, /* Name */ +.highlight .nv, /* Name.Variable */ +.codehilite .n, +.codehilite .nv { + color: #24292f; +} + +/* Decorators: @decorator */ +.highlight .nd, /* Name.Decorator */ +.codehilite .nd { + color: #8250df; +} + +/* Module/Package Names */ +.highlight .nn, /* Name.Namespace */ +.codehilite .nn { + color: #24292f; +} + +/* ======================================== + Dark Mode Syntax Highlighting + ======================================== */ + +.dark .highlight .k, +.dark .highlight .kn, +.dark .highlight .kd, +.dark .highlight .kr, +.dark .highlight .kc, +.dark .codehilite .k, +.dark .codehilite .kn, +.dark .codehilite .kd, +.dark .codehilite .kr, +.dark .codehilite .kc, +.dark .highlight .k, +.dark .highlight .kn, +.dark .highlight .kd, +.dark .highlight .kr, +.dark .highlight .kc { + color: #ff7b72; + font-weight: 500; +} + +.dark .highlight .s, +.dark .highlight .s1, +.dark .highlight .s2, +.dark .highlight .se, +.dark .codehilite .s, +.dark .codehilite .s1, +.dark .codehilite .s2, +.dark .codehilite .se, +.dark .highlight .s, +.dark .highlight .s1, +.dark .highlight .s2, +.dark .highlight .se { + color: #a5d6ff; +} + +.dark .highlight .c, +.dark .highlight .c1, +.dark .highlight .cm, +.dark .highlight .cp, +.dark .codehilite .c, +.dark .codehilite .c1, +.dark .codehilite .cm, +.dark .codehilite .cp, +.dark .highlight .c, +.dark .highlight .c1, +.dark .highlight .cm, +.dark .highlight .cp { + color: #8b949e; + font-style: italic; +} + +.dark .highlight .nf, +.dark .highlight .nc, +.dark .codehilite .nf, +.dark .codehilite .nc, +.dark .highlight .nf, +.dark .highlight .nc { + color: #d2a8ff; +} + +.dark .highlight .nb, +.dark .highlight .bp, +.dark .codehilite .nb, +.dark .codehilite .bp, +.dark .highlight .nb, +.dark .highlight .bp { + color: #ffa657; +} + +.dark .highlight .m, +.dark .highlight .mi, +.dark .highlight .mf, +.dark .codehilite .m, +.dark .codehilite .mi, +.dark .codehilite .mf, +.dark .highlight .m, +.dark .highlight .mi, +.dark .highlight .mf { + color: #79c0ff; +} + +.dark .highlight .o, +.dark .highlight .ow, +.dark .codehilite .o, +.dark .codehilite .ow, +.dark .highlight .o, +.dark .highlight .ow { + color: #ff7b72; + font-weight: 500; +} + +.dark .highlight .p, +.dark .codehilite .p, +.dark .highlight .p, +.dark .codehilite .p { + color: #c9d1d9; +} + +.dark .highlight .n, +.dark .highlight .nv, +.dark .codehilite .n, +.dark .codehilite .nv, +.dark .highlight .n, +.dark .highlight .nv { + color: #c9d1d9; +} + +.dark .highlight .nd, +.dark .codehilite .nd, +.dark .highlight .nd, +.dark .codehilite .nd { + color: #d2a8ff; +} + +.dark .highlight .nn, +.dark .codehilite .nn, +.dark .highlight .nn, +.dark .codehilite .nn { + color: #c9d1d9; +} + +/* ======================================== + Special Highlighting + ======================================== */ + +/* Highlighted lines */ +.highlight .hll, +.codehilite .hll { + background-color: rgba(255, 213, 0, 0.15); + display: block; + margin: 0 -1.25rem; + padding: 0 1.25rem; +} + +.dark .highlight .hll, +.dark .codehilite .hll, +.dark .highlight .hll, +.dark .codehilite .hll { + background-color: rgba(255, 213, 0, 0.1); +} + +/* Error highlighting */ +.highlight .err, +.codehilite .err { + color: #cf222e; +} + +.dark .highlight .err, +.dark .codehilite .err, +.dark .highlight .err, +.dark .codehilite .err { + color: #ff7b72; +} + +/* ======================================== + Language-Specific Adjustments + ======================================== */ + +/* Python-specific */ +.highlight .language-python .nv, +.codehilite .language-python .nv { + color: #24292f; +} + +.dark .highlight .language-python .nv, +.dark .codehilite .language-python .nv, +.dark .highlight .language-python .nv, +.dark .codehilite .language-python .nv { + color: #c9d1d9; +} + +/* JavaScript/TypeScript-specific */ +.highlight .language-javascript .kd, +.highlight .language-typescript .kd, +.codehilite .language-javascript .kd, +.codehilite .language-typescript .kd { + color: #cf222e; + font-weight: 500; +} + +/* JSON-specific */ +.highlight .language-json .nd, +.codehilite .language-json .nd { + color: #0550ae; +} + +.dark .highlight .language-json .nd, +.dark .codehilite .language-json .nd, +.dark .highlight .language-json .nd, +.dark .codehilite .language-json .nd { + color: #79c0ff; +} diff --git a/docs/stylesheets/tabbed-code.css b/docs/stylesheets/tabbed-code.css new file mode 100644 index 00000000..220b4d3d --- /dev/null +++ b/docs/stylesheets/tabbed-code.css @@ -0,0 +1,410 @@ +/* + * Tabbed Code Blocks Styling for pymdownx.tabbed alternate_style with shadcn/ui theme + * Override base.css styles for alternate_style: true + */ + +/* ======================================== + Override base.css tabbed-set styles for alternate_style + Using higher specificity selectors + ======================================== */ + +/* Reset base.css flex-wrap that breaks alternate layout */ +article .tabbed-set.tabbed-alternate { + flex-flow: column nowrap !important; + display: flex !important; + margin: 1.5em 0 !important; + /* Fallback for browsers without OKLCH support */ + border: 1px solid #e5e7eb !important; + border: 1px solid var(--border, #e5e7eb) !important; + border-radius: 0.5rem !important; + overflow: hidden !important; + background: #ffffff !important; + background: var(--background, #ffffff) !important; + box-shadow: none !important; +} + +/* Special styling for Workflow tabs - no border */ +article .tabbed-set.tabbed-alternate:has(.workflow) { + border: none !important; + border-radius: 0 !important; + background: transparent !important; + overflow: visible !important; +} + +/* Hide radio inputs */ +article .tabbed-set.tabbed-alternate > input[type="radio"] { + position: absolute !important; + width: 1px !important; + height: 1px !important; + padding: 0 !important; + margin: -1px !important; + overflow: hidden !important; + clip: rect(0, 0, 0, 0) !important; + white-space: nowrap !important; + border: 0 !important; + display: block !important; /* Override base.css display: none */ +} + +/* Tab labels container */ +article .tabbed-set.tabbed-alternate > .tabbed-labels { + display: flex !important; + flex-direction: row !important; + background: transparent !important; + /* Fallback for browsers without OKLCH support */ + border-bottom: 1px solid #e5e7eb !important; + border-bottom: 1px solid var(--border, #e5e7eb) !important; + border-top-left-radius: 0 !important; + border-top-right-radius: 0 !important; + padding: 0.5rem 0.5rem 0 0.5rem !important; + padding-right: 3rem !important; + margin: 0 !important; + overflow-x: auto !important; + order: 1 !important; + position: relative !important; + scrollbar-width: none !important; + -ms-overflow-style: none !important; +} + +/* Hide scrollbar for webkit browsers */ +article .tabbed-set.tabbed-alternate > .tabbed-labels::-webkit-scrollbar { + display: none !important; +} + +/* Tab label buttons - override base.css label styles */ +article .tabbed-set.tabbed-alternate > .tabbed-labels > label { + padding: 0.5rem 1rem !important; + font-size: 0.875rem !important; + font-weight: 500 !important; + line-height: 1.25rem !important; + /* Fallback for browsers without OKLCH support */ + color: #6b7280 !important; + color: var(--muted-foreground, #6b7280) !important; + cursor: pointer !important; + border: none !important; + border-bottom: 2px solid transparent !important; + border-radius: 0 !important; + background: transparent !important; + transition: all 0.2s ease !important; + white-space: nowrap !important; + position: relative !important; + margin: 0 !important; + margin-bottom: -1px !important; + display: inline-flex !important; + align-items: center !important; + user-select: none !important; + order: unset !important; + flex-basis: auto !important; +} + +article .tabbed-set.tabbed-alternate > .tabbed-labels > label:first-of-type { + margin-left: 0 !important; +} + +article .tabbed-set.tabbed-alternate > .tabbed-labels > label:hover { + color: var(--foreground) !important; + background: transparent !important; +} + +/* Tab content container - override base.css */ +article .tabbed-set.tabbed-alternate > .tabbed-content { + display: block !important; + position: relative !important; + background: transparent !important; + order: 2 !important; + flex-basis: auto !important; + border-top: none !important; + padding: 0 !important; + margin: 0 !important; +} + +/* Ensure content flows naturally inside the border */ +article .tabbed-set.tabbed-alternate .tabbed-block { + padding: 0 !important; + margin: 0 !important; +} + +/* Individual tab blocks - hide by default */ +article .tabbed-set.tabbed-alternate > .tabbed-content > .tabbed-block { + display: none !important; +} + +/* ======================================== + Active Tab States - CSS :checked method + ======================================== */ + +/* Active label - Tab 1 */ +article .tabbed-set.tabbed-alternate > input:nth-child(1):checked ~ .tabbed-labels > label:nth-child(1) { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* Active label - Tab 2 */ +article .tabbed-set.tabbed-alternate > input:nth-child(2):checked ~ .tabbed-labels > label:nth-child(2) { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* Active label - Tab 3 */ +article .tabbed-set.tabbed-alternate > input:nth-child(3):checked ~ .tabbed-labels > label:nth-child(3) { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* Active label - Tab 4 */ +article .tabbed-set.tabbed-alternate > input:nth-child(4):checked ~ .tabbed-labels > label:nth-child(4) { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* Active label - Tab 5 */ +article .tabbed-set.tabbed-alternate > input:nth-child(5):checked ~ .tabbed-labels > label:nth-child(5) { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* Active label - Tab 6 */ +article .tabbed-set.tabbed-alternate > input:nth-child(6):checked ~ .tabbed-labels > label:nth-child(6) { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* Show active tab content */ +article .tabbed-set.tabbed-alternate > input:nth-child(1):checked ~ .tabbed-content > .tabbed-block:nth-child(1) { + display: block !important; +} + +article .tabbed-set.tabbed-alternate > input:nth-child(2):checked ~ .tabbed-content > .tabbed-block:nth-child(2) { + display: block !important; +} + +article .tabbed-set.tabbed-alternate > input:nth-child(3):checked ~ .tabbed-content > .tabbed-block:nth-child(3) { + display: block !important; +} + +article .tabbed-set.tabbed-alternate > input:nth-child(4):checked ~ .tabbed-content > .tabbed-block:nth-child(4) { + display: block !important; +} + +article .tabbed-set.tabbed-alternate > input:nth-child(5):checked ~ .tabbed-content > .tabbed-block:nth-child(5) { + display: block !important; +} + +article .tabbed-set.tabbed-alternate > input:nth-child(6):checked ~ .tabbed-content > .tabbed-block:nth-child(6) { + display: block !important; +} + +/* JS fallback method */ +article .tabbed-set.tabbed-alternate > .tabbed-content > .tabbed-block.tabbed-block--active { + display: block !important; +} + +article .tabbed-set.tabbed-alternate > .tabbed-labels > label.tabbed-label--active, +article .tabbed-set.tabbed-alternate > .tabbed-labels > label[data-active="true"] { + color: #14b8a6 !important; + background: transparent !important; + border-bottom-color: #14b8a6 !important; + box-shadow: none !important; +} + +/* ======================================== + Code Block Styling Inside Tabs + ======================================== */ + +article .tabbed-set.tabbed-alternate .tabbed-block .highlight, +article .tabbed-set.tabbed-alternate .tabbed-block .codehilite { + margin: 0 !important; + border-radius: 0 !important; + border: none !important; + position: relative !important; +} + +article .tabbed-set.tabbed-alternate .tabbed-block pre { + margin: 0 !important; + border-radius: 0 !important; + border: none !important; + background: transparent !important; + position: relative !important; +} + +article .tabbed-set.tabbed-alternate .tabbed-block pre code { + display: block !important; + padding: 1rem 1.25rem !important; + overflow-x: auto !important; + font-size: 0.8125rem !important; + line-height: 1.7 !important; +} + +/* Copy button for tabbed code blocks */ +article .tabbed-set.tabbed-alternate .tabbed-labels .copy-button { + position: absolute !important; + top: 0.625rem !important; + right: 0.75rem !important; + padding: 0.375rem 0.5rem !important; + font-size: 0.75rem !important; + font-weight: 500 !important; + /* Fallback for browsers without OKLCH support */ + color: #6b7280 !important; + color: var(--muted-foreground, #6b7280) !important; + background: #ffffff !important; + background: var(--background, #ffffff) !important; + border: 1px solid #e5e7eb !important; + border: 1px solid var(--border, #e5e7eb) !important; + border-radius: 0.375rem !important; + cursor: pointer !important; + opacity: 0 !important; + transition: all 0.15s ease !important; + z-index: 10 !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + min-width: 2rem !important; + height: 2rem !important; +} + +article .tabbed-set.tabbed-alternate .tabbed-labels .copy-button svg { + width: 1rem !important; + height: 1rem !important; + display: block !important; +} + +article .tabbed-set.tabbed-alternate:hover .copy-button { + opacity: 1 !important; +} + +article .tabbed-set.tabbed-alternate .copy-button:hover { + color: var(--foreground, #1f2937) !important; + background: var(--muted, #f3f4f6) !important; + border-color: var(--border, #d1d5db) !important; +} + +article .tabbed-set.tabbed-alternate .copy-button.copied { + color: var(--success, #10b981) !important; + border-color: var(--success, #10b981) !important; +} + +/* ======================================== + Dark Mode Support + ======================================== */ + +.dark article .tabbed-set.tabbed-alternate { + /* Fallback for browsers without OKLCH support */ + background: #0a0a0a !important; + background: var(--background, #0a0a0a) !important; + border: 1px solid #374151 !important; + border: 1px solid var(--border, #374151) !important; +} + +/* Special styling for Workflow tabs in dark mode - no border */ +.dark article .tabbed-set.tabbed-alternate:has(.workflow), +.dark article .tabbed-set.tabbed-alternate:has(.workflow) { + border: none !important; + background: transparent !important; +} + +.dark article .tabbed-set.tabbed-alternate > .tabbed-labels { + background: transparent !important; + /* Fallback for browsers without OKLCH support */ + border-bottom-color: #374151 !important; + border-bottom-color: var(--border, #374151) !important; +} + +.dark article .tabbed-set.tabbed-alternate > .tabbed-labels > label { + color: var(--muted-foreground) !important; +} + +.dark article .tabbed-set.tabbed-alternate > .tabbed-labels > label:hover { + color: var(--foreground) !important; + background: rgba(255, 255, 255, 0.05) !important; +} + +.dark article .tabbed-set.tabbed-alternate > input:checked ~ .tabbed-labels > label.tabbed-label--active, +.dark article .tabbed-set.tabbed-alternate > .tabbed-labels > label[data-active="true"], +.dark article .tabbed-set.tabbed-alternate > input:nth-child(1):checked ~ .tabbed-labels > label:nth-child(1), +.dark article .tabbed-set.tabbed-alternate > input:nth-child(2):checked ~ .tabbed-labels > label:nth-child(2), +.dark article .tabbed-set.tabbed-alternate > input:nth-child(3):checked ~ .tabbed-labels > label:nth-child(3), +.dark article .tabbed-set.tabbed-alternate > input:nth-child(4):checked ~ .tabbed-labels > label:nth-child(4), +.dark article .tabbed-set.tabbed-alternate > input:nth-child(5):checked ~ .tabbed-labels > label:nth-child(5), +.dark article .tabbed-set.tabbed-alternate > input:nth-child(6):checked ~ .tabbed-labels > label:nth-child(6) { + color: #2dd4bf !important; + background: transparent !important; + border-bottom-color: #2dd4bf !important; + box-shadow: none !important; +} + +.dark article .tabbed-set.tabbed-alternate .tabbed-block pre, +.dark article .tabbed-set.tabbed-alternate .tabbed-block pre { + background: transparent !important; + border: none !important; +} + +.dark article .tabbed-set.tabbed-alternate .tabbed-block .highlight, +.dark article .tabbed-set.tabbed-alternate .tabbed-block .codehilite, +.dark article .tabbed-set.tabbed-alternate .tabbed-block .highlight, +.dark article .tabbed-set.tabbed-alternate .tabbed-block .codehilite { + border: none !important; +} + +.dark article .tabbed-set.tabbed-alternate .tabbed-labels .copy-button { + /* Fallback for browsers without OKLCH support */ + background: #1f2937 !important; + background: var(--background, #1f2937) !important; + border-color: #374151 !important; + border-color: var(--border, #374151) !important; + color: #9ca3af !important; + color: var(--muted-foreground, #9ca3af) !important; +} + +.dark article .tabbed-set.tabbed-alternate .copy-button:hover { + /* Fallback for browsers without OKLCH support */ + background: #374151 !important; + background: var(--muted, #374151) !important; + color: #e5e7eb !important; + color: var(--foreground, #e5e7eb) !important; +} + +/* ======================================== + Responsive Design + ======================================== */ + +@media (max-width: 640px) { + article .tabbed-set.tabbed-alternate > .tabbed-labels { + padding: 0.375rem 2.75rem 0 0.375rem !important; + } + + article .tabbed-set.tabbed-alternate > .tabbed-labels > label { + padding: 0.375rem 0.625rem !important; + font-size: 0.8125rem !important; + margin: 0 0.125rem !important; + } + + article .tabbed-set.tabbed-alternate .tabbed-block pre code { + padding: 0.875rem 1rem !important; + font-size: 0.75rem !important; + } + + article .tabbed-set.tabbed-alternate .tabbed-labels .copy-button { + opacity: 1 !important; + min-width: 1.75rem !important; + height: 1.75rem !important; + padding: 0.25rem !important; + top: 0.5rem !important; + right: 0.5rem !important; + } + + article .tabbed-set.tabbed-alternate .tabbed-labels .copy-button svg { + width: 0.875rem !important; + height: 0.875rem !important; + } +} diff --git a/docs/stylesheets/table-enhancements.css b/docs/stylesheets/table-enhancements.css new file mode 100644 index 00000000..78be88b2 --- /dev/null +++ b/docs/stylesheets/table-enhancements.css @@ -0,0 +1,321 @@ +/* + * Table Enhancements for OpenJudge Documentation + * Phase 1: 表格样式增强 + * + * Features: + * - 表头样式优化 + * - 单元格垂直居中 + * - 斑马条纹 + * - 悬停高亮 + * - 响应式滚动 + * - 暗色模式支持 + */ + +/* ======================================== + Base Table Styles + ======================================== */ + +article table, +.prose table, +.md-typeset table:not([class]) { + width: 100%; + border-collapse: separate; + border-spacing: 0; + margin: 1.5em 0; + font-size: 0.9375rem; + line-height: 1.6; + overflow: hidden; + border: 1px solid var(--border, #e5e7eb); + border-radius: var(--radius-lg, 0.5rem); +} + +/* ======================================== + Table Header + ======================================== */ + +article table thead, +.prose table thead, +.md-typeset table:not([class]) thead { + background: var(--muted, #f9fafb); +} + +article table th, +.prose table th, +.md-typeset table:not([class]) th { + padding: 0.75rem 1rem; + font-weight: 600; + font-size: 0.875rem; + text-align: left; + color: var(--foreground, #111827); + border-bottom: 1px solid var(--border, #e5e7eb); + white-space: nowrap; +} + +/* First header cell - round top-left corner */ +article table th:first-child, +.prose table th:first-child, +.md-typeset table:not([class]) th:first-child { + border-top-left-radius: calc(var(--radius-lg, 0.5rem) - 1px); +} + +/* Last header cell - round top-right corner */ +article table th:last-child, +.prose table th:last-child, +.md-typeset table:not([class]) th:last-child { + border-top-right-radius: calc(var(--radius-lg, 0.5rem) - 1px); +} + +/* ======================================== + Table Body & Cells + ======================================== */ + +article table td, +.prose table td, +.md-typeset table:not([class]) td { + padding: 0.75rem 1rem; + vertical-align: middle; + color: var(--foreground, #374151); + border-bottom: 1px solid var(--border, #e5e7eb); + line-height: 1.5; +} + +/* Remove border from last row */ +article table tbody tr:last-child td, +.prose table tbody tr:last-child td, +.md-typeset table:not([class]) tbody tr:last-child td { + border-bottom: none; +} + +/* ======================================== + Zebra Stripes (Alternating Row Colors) + ======================================== */ + +article table tbody tr:nth-child(even), +.prose table tbody tr:nth-child(even), +.md-typeset table:not([class]) tbody tr:nth-child(even) { + background: rgba(0, 0, 0, 0.02); +} + +/* ======================================== + Row Hover Effect + ======================================== */ + +article table tbody tr, +.prose table tbody tr, +.md-typeset table:not([class]) tbody tr { + transition: background-color 0.15s ease; +} + +article table tbody tr:hover, +.prose table tbody tr:hover, +.md-typeset table:not([class]) tbody tr:hover { + background: rgba(0, 0, 0, 0.04); +} + +/* ======================================== + Code in Table Cells + ======================================== */ + +article table code, +.prose table code, +.md-typeset table:not([class]) code { + font-size: 0.8125rem; + padding: 0.125rem 0.375rem; + background: var(--muted, #f3f4f6); + border-radius: 0.25rem; + font-family: 'JetBrains Mono', ui-monospace, monospace; +} + +/* ======================================== + Links in Table Cells + ======================================== */ + +article table a, +.prose table a, +.md-typeset table:not([class]) a { + color: var(--primary, #3b82f6); + text-decoration: none; + font-weight: 500; +} + +article table a:hover, +.prose table a:hover, +.md-typeset table:not([class]) a:hover { + text-decoration: underline; + text-underline-offset: 2px; +} + +/* ======================================== + Dark Mode + ======================================== */ + +.dark article table, +.dark .prose table, +.dark .md-typeset table:not([class]), +.dark article table, +.dark .prose table, +.dark .md-typeset table:not([class]) { + border-color: var(--border, #374151); +} + +.dark article table thead, +.dark .prose table thead, +.dark .md-typeset table:not([class]) thead, +.dark article table thead, +.dark .prose table thead, +.dark .md-typeset table:not([class]) thead { + background: var(--muted, #1f2937); +} + +.dark article table th, +.dark .prose table th, +.dark .md-typeset table:not([class]) th, +.dark article table th, +.dark .prose table th, +.dark .md-typeset table:not([class]) th { + color: var(--foreground, #f9fafb); + border-bottom-color: var(--border, #374151); +} + +.dark article table td, +.dark .prose table td, +.dark .md-typeset table:not([class]) td, +.dark article table td, +.dark .prose table td, +.dark .md-typeset table:not([class]) td { + color: var(--foreground, #e5e7eb); + border-bottom-color: var(--border, #374151); +} + +/* Dark mode zebra stripes */ +.dark article table tbody tr:nth-child(even), +.dark .prose table tbody tr:nth-child(even), +.dark .md-typeset table:not([class]) tbody tr:nth-child(even), +.dark article table tbody tr:nth-child(even), +.dark .prose table tbody tr:nth-child(even), +.dark .md-typeset table:not([class]) tbody tr:nth-child(even) { + background: rgba(255, 255, 255, 0.02); +} + +/* Dark mode hover */ +.dark article table tbody tr:hover, +.dark .prose table tbody tr:hover, +.dark .md-typeset table:not([class]) tbody tr:hover, +.dark article table tbody tr:hover, +.dark .prose table tbody tr:hover, +.dark .md-typeset table:not([class]) tbody tr:hover { + background: rgba(255, 255, 255, 0.05); +} + +/* Dark mode code in tables */ +.dark article table code, +.dark .prose table code, +.dark .md-typeset table:not([class]) code, +.dark article table code, +.dark .prose table code, +.dark .md-typeset table:not([class]) code { + background: var(--muted, #374151); +} + +/* ======================================== + Responsive Table (Horizontal Scroll) + ======================================== */ + +.table-responsive, +.md-typeset .table-responsive { + width: 100%; + overflow-x: auto; + -webkit-overflow-scrolling: touch; + margin: 1.5em 0; +} + +.table-responsive table, +.md-typeset .table-responsive table { + margin: 0; + min-width: 600px; +} + +/* Scroll shadow indicators */ +.table-responsive { + position: relative; +} + +.table-responsive::before, +.table-responsive::after { + content: ''; + position: absolute; + top: 0; + bottom: 0; + width: 20px; + pointer-events: none; + opacity: 0; + transition: opacity 0.2s ease; + z-index: 1; +} + +.table-responsive::before { + left: 0; + background: linear-gradient(to right, var(--background, #fff), transparent); +} + +.table-responsive::after { + right: 0; + background: linear-gradient(to left, var(--background, #fff), transparent); +} + +.table-responsive.scroll-left::before, +.table-responsive.scroll-right::after { + opacity: 1; +} + +/* ======================================== + Compact Table Variant + ======================================== */ + +article table.compact th, +article table.compact td, +.prose table.compact th, +.prose table.compact td { + padding: 0.5rem 0.75rem; + font-size: 0.8125rem; +} + +/* ======================================== + Wide Table Variant + ======================================== */ + +article table.wide, +.prose table.wide { + min-width: 100%; +} + +/* ======================================== + Mobile Responsive + ======================================== */ + +@media (max-width: 640px) { + article table, + .prose table, + .md-typeset table:not([class]) { + font-size: 0.875rem; + display: block; + overflow-x: auto; + -webkit-overflow-scrolling: touch; + } + + article table th, + article table td, + .prose table th, + .prose table td, + .md-typeset table:not([class]) th, + .md-typeset table:not([class]) td { + padding: 0.625rem 0.75rem; + white-space: nowrap; + } + + article table th, + .prose table th, + .md-typeset table:not([class]) th { + font-size: 0.8125rem; + } +} diff --git a/docs/stylesheets/tuner_v2.md b/docs/stylesheets/tuner_v2.md new file mode 100644 index 00000000..c19509cd --- /dev/null +++ b/docs/stylesheets/tuner_v2.md @@ -0,0 +1,81 @@ + + + +# Type 1-1: AgentScope Agents + +```python +class ExampleMathLearn(Workflow): + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), # 🌟 this will do the trick + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + max_iters=2, + ) + + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) + +``` + + + +# Type 1-2: AgentScope Agents: Triple-M (Multi-Role, Multi-Agent, Multi-Turn) Case + +```python +roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] +players = [] +for i, agent_role in enumerate(roles): + if agent_role != "werewolf": + chosen_model_for_current_agent = OpenAIChatModel(model_name="qwen-max", stream=False) + else: + chosen_model_for_current_agent = OpenAIChatModel(model_name="qwen-plus", stream=False) + players += [ReActAgent( + name=f"Player{i + 1}", + sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), + model=agentscope_model, + model=tuner.as_agentscope_model( + agent_name=f"Player{i + 1}", + target_tag=agent_role, # 🌟 tag agents with their role + debug_model=chosen_model_for_current_agent # 🌟 assign a debug model, ONLY used when we are NOT training this agent + ) + formatter=OpenAIMultiAgentFormatter(), + )] +``` + + + +# Type 2: Raw OpenAI SDK Agents + +```python + +import openai +client = openai.OpenAI(api_key='dummy-api-key') + +response = client.chat.completions.create( + model="gpt-3.5-turbo", # You can replace this with "gpt-4" if available + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello! Tell me a joke about programming."} + ], + max_tokens=100, # Limit the response length + temperature=0.7 # Control the randomness of the output +) + + +``` diff --git a/docs/stylesheets/workflow.css b/docs/stylesheets/workflow.css new file mode 100644 index 00000000..2bac2c94 --- /dev/null +++ b/docs/stylesheets/workflow.css @@ -0,0 +1,290 @@ +/* Workflow Component Styling */ +/* A step-by-step workflow display with numbered badges and vertical connector lines */ + +/* Workflow container */ +.workflow { + margin: 1.5em 0; + padding: 1rem 0; +} + +/* Workflow title */ +.workflow-title { + font-size: 1.5rem; + font-weight: 600; + margin-bottom: 1rem; + color: #1f2937; +} + +.dark .workflow-title, +.dark .workflow-title { + color: #f3f4f6; +} + +/* Workflow steps list */ +.workflow ol, +ol.workflow-steps { + list-style: none; + padding: 0; + margin: 0; + counter-reset: workflow-counter; +} + +/* Individual step item */ +.workflow ol > li, +ol.workflow-steps > li { + position: relative; + padding: 0 0 1.75rem 3.5rem; + margin: 0; + counter-increment: workflow-counter; +} + +.workflow ol > li:last-child, +ol.workflow-steps > li:last-child { + padding-bottom: 0; +} + +/* Step number badge */ +.workflow ol > li::before, +ol.workflow-steps > li::before { + content: counter(workflow-counter); + position: absolute; + left: 0; + top: 0; + width: 2rem; + height: 2rem; + background: #f9fafb; + border: 1.5px solid #d1d5db; + border-radius: 50%; + display: flex; + align-items: center; + justify-content: center; + font-size: 0.875rem; + font-weight: 500; + color: #6b7280; + z-index: 2; +} + +.dark .workflow ol > li::before, +.dark ol.workflow-steps > li::before, +.dark .workflow ol > li::before, +.dark ol.workflow-steps > li::before { + background: #374151; + border-color: #4b5563; + color: #d1d5db; +} + +/* Vertical connector line */ +.workflow ol > li::after, +ol.workflow-steps > li::after { + content: ''; + position: absolute; + left: calc(1rem - 0.5px); + top: 2rem; + width: 1px; + height: calc(100% - 2rem); + background: #d1d5db; + z-index: 1; +} + +/* Hide connector line on last item */ +.workflow ol > li:last-child::after, +ol.workflow-steps > li:last-child::after { + display: none; +} + +.dark .workflow ol > li::after, +.dark ol.workflow-steps > li::after, +.dark .workflow ol > li::after, +.dark ol.workflow-steps > li::after { + background: #4b5563; +} + +/* Step title */ +.workflow ol > li strong:first-child, +ol.workflow-steps > li strong:first-child, +.workflow-step-title { + display: block; + font-size: 1rem; + font-weight: 600; + color: #111827; + margin-bottom: 0.5rem; + line-height: 2rem; +} + +.dark .workflow ol > li strong:first-child, +.dark ol.workflow-steps > li strong:first-child, +.dark .workflow-step-title, +.dark .workflow ol > li strong:first-child, +.dark ol.workflow-steps > li strong:first-child, +.dark .workflow-step-title { + color: #f9fafb; +} + +/* Step description */ +.workflow ol > li p, +ol.workflow-steps > li p, +.workflow-step-desc { + margin: 0 0 0.5rem 0; + font-size: 0.9375rem; + color: #4b5563; + line-height: 1.6; +} + +.dark .workflow ol > li p, +.dark ol.workflow-steps > li p, +.dark .workflow-step-desc, +.dark .workflow ol > li p, +.dark ol.workflow-steps > li p, +.dark .workflow-step-desc { + color: #9ca3af; +} + +/* Links in workflow */ +.workflow a, +ol.workflow-steps a { + color: #059669; + text-decoration: underline; + text-underline-offset: 2px; +} + +.workflow a:hover, +ol.workflow-steps a:hover { + color: #047857; +} + +.dark .workflow a, +.dark ol.workflow-steps a, +.dark .workflow a, +.dark ol.workflow-steps a { + color: #34d399; +} + +.dark .workflow a:hover, +.dark ol.workflow-steps a:hover, +.dark .workflow a:hover, +.dark ol.workflow-steps a:hover { + color: #6ee7b7; +} + +/* Nested list in workflow steps */ +.workflow ol > li ul, +ol.workflow-steps > li ul { + margin: 0.5rem 0 0 0; + padding-left: 1.25rem; + list-style: disc; +} + +.workflow ol > li ul li, +ol.workflow-steps > li ul li { + padding: 0.25rem 0; + font-size: 0.9375rem; + color: #4b5563; +} + +/* Reset nested list items - don't apply workflow counter styles */ +.workflow ol > li ul li::before, +ol.workflow-steps > li ul li::before, +.workflow ol > li ul li::after, +ol.workflow-steps > li ul li::after { + display: none; + content: none; +} + +.dark .workflow ol > li ul li, +.dark ol.workflow-steps > li ul li, +.dark .workflow ol > li ul li, +.dark ol.workflow-steps > li ul li { + color: #9ca3af; +} + +/* Workflow inside tabs - seamless integration */ +.tabbed-block .workflow { + padding: 1.5rem 1rem; +} + +.tabbed-block .workflow ol, +.tabbed-block ol.workflow-steps { + margin: 0; +} + +.tabbed-block .workflow ol > li, +.tabbed-block ol.workflow-steps > li { + padding: 0 0 1.75rem 3.5rem; +} + +/* Single workflow (no tabs needed) */ +.workflow-single { + margin: 1.5em 0; + padding: 0; + border: 1px solid #e5e7eb; + border-radius: 0.5rem; + overflow: hidden; +} + +.workflow-single .workflow-header { + font-size: 0.875rem; + font-weight: 500; + color: #111827; + padding: 0.75rem 1rem; + border-bottom: 1px solid #e5e7eb; + background: #f9fafb; +} + +.dark .workflow-single, +.dark .workflow-single { + border-color: #374151; +} + +.dark .workflow-single .workflow-header, +.dark .workflow-single .workflow-header { + color: #f3f4f6; + border-bottom-color: #374151; + background: #1f2937; +} + +.workflow-single .workflow { + padding: 1.5rem 1rem 1rem; + margin: 0; +} + +/* Compact workflow variant */ +.workflow-compact ol > li, +ol.workflow-steps.compact > li { + padding: 0 0 1.25rem 3rem; +} + +.workflow-compact ol > li::before, +ol.workflow-steps.compact > li::before { + width: 1.75rem; + height: 1.75rem; + font-size: 0.8125rem; +} + +.workflow-compact ol > li::after, +ol.workflow-steps.compact > li::after { + left: calc(0.875rem - 0.5px); + top: 1.75rem; + height: calc(100% - 1.75rem); +} + +/* Responsive */ +@media (max-width: 640px) { + .workflow ol > li, + ol.workflow-steps > li { + padding-left: 3rem; + } + + .workflow ol > li::before, + ol.workflow-steps > li::before { + width: 1.75rem; + height: 1.75rem; + font-size: 0.8125rem; + } + + .workflow ol > li::after, + ol.workflow-steps > li::after { + left: calc(0.875rem - 0.5px); + top: 1.75rem; + height: calc(100% - 1.75rem); + } +} diff --git a/install.sh b/install.sh new file mode 100755 index 00000000..2306bad0 --- /dev/null +++ b/install.sh @@ -0,0 +1,2080 @@ +#!/bin/sh +# shellcheck shell=dash +# shellcheck disable=SC2039 # local is non-POSIX +# +# Licensed under the MIT license +# , at your +# option. This file may not be copied, modified, or distributed +# except according to those terms. + +# This runs on Unix shells like bash/dash/ksh/zsh. It uses the common `local` +# extension. Note: Most shells limit `local` to 1 var per line, contra bash. + +# Some versions of ksh have no `local` keyword. Alias it to `typeset`, but +# beware this makes variables global with f()-style function syntax in ksh93. +# mksh has this alias by default. +has_local() { + # shellcheck disable=SC2034 # deliberately unused + local _has_local +} + +has_local 2>/dev/null || alias local=typeset + +set -u + +APP_NAME="uv" +APP_VERSION="0.9.22" +# Look for GitHub Enterprise-style base URL first +if [ -n "${UV_INSTALLER_GHE_BASE_URL:-}" ]; then + INSTALLER_BASE_URL="$UV_INSTALLER_GHE_BASE_URL" +else + INSTALLER_BASE_URL="${UV_INSTALLER_GITHUB_BASE_URL:-https://github.com}" +fi +if [ -n "${UV_DOWNLOAD_URL:-}" ]; then + ARTIFACT_DOWNLOAD_URL="$UV_DOWNLOAD_URL" +elif [ -n "${INSTALLER_DOWNLOAD_URL:-}" ]; then + ARTIFACT_DOWNLOAD_URL="$INSTALLER_DOWNLOAD_URL" +else + ARTIFACT_DOWNLOAD_URL="${INSTALLER_BASE_URL}/astral-sh/uv/releases/download/0.9.22" +fi +if [ -n "${UV_PRINT_VERBOSE:-}" ]; then + PRINT_VERBOSE="$UV_PRINT_VERBOSE" +else + PRINT_VERBOSE=${INSTALLER_PRINT_VERBOSE:-0} +fi +if [ -n "${UV_PRINT_QUIET:-}" ]; then + PRINT_QUIET="$UV_PRINT_QUIET" +else + PRINT_QUIET=${INSTALLER_PRINT_QUIET:-0} +fi +if [ -n "${UV_NO_MODIFY_PATH:-}" ]; then + NO_MODIFY_PATH="$UV_NO_MODIFY_PATH" +else + NO_MODIFY_PATH=${INSTALLER_NO_MODIFY_PATH:-0} +fi +if [ "${UV_DISABLE_UPDATE:-0}" = "1" ]; then + INSTALL_UPDATER=0 +else + INSTALL_UPDATER=1 +fi +UNMANAGED_INSTALL="${UV_UNMANAGED_INSTALL:-}" +if [ -n "${UNMANAGED_INSTALL}" ]; then + NO_MODIFY_PATH=1 + INSTALL_UPDATER=0 +fi +AUTH_TOKEN="${UV_GITHUB_TOKEN:-}" + +read -r RECEIPT <&2 + say_verbose " from $_url" 1>&2 + say_verbose " to $_file" 1>&2 + + ensure mkdir -p "$_dir" + + if ! downloader "$_url" "$_file"; then + say "failed to download $_url" + say "this may be a standard network error, but it may also indicate" + say "that $APP_NAME's release process is not working. When in doubt" + say "please feel free to open an issue!" + exit 1 + fi + + if [ -n "${_checksum_style:-}" ]; then + verify_checksum "$_file" "$_checksum_style" "$_checksum_value" + else + say "no checksums to verify" + fi + + # ...and then the updater, if it exists + if [ -n "$_updater_name" ] && [ "$INSTALL_UPDATER" = "1" ]; then + local _updater_url="$ARTIFACT_DOWNLOAD_URL/$_updater_name" + # This renames the artifact while doing the download, removing the + # target triple and leaving just the appname-update format + local _updater_file="$_dir/$APP_NAME-update" + + if ! downloader "$_updater_url" "$_updater_file"; then + say "failed to download $_updater_url" + say "this may be a standard network error, but it may also indicate" + say "that $APP_NAME's release process is not working. When in doubt" + say "please feel free to open an issue!" + exit 1 + fi + + # Add the updater to the list of binaries to install + _bins="$_bins $APP_NAME-update" + fi + + # unpack the archive + case "$_zip_ext" in + ".zip") + ensure unzip -q "$_file" -d "$_dir" + ;; + + ".tar."*) + ensure tar xf "$_file" --strip-components 1 -C "$_dir" + ;; + *) + err "unknown archive format: $_zip_ext" + ;; + esac + + install "$_dir" "$_bins" "$_libs" "$_staticlibs" "$_arch" "$@" + local _retval=$? + if [ "$_retval" != 0 ]; then + return "$_retval" + fi + + ignore rm -rf "$_dir" + + # Install the install receipt + if [ "$INSTALL_UPDATER" = "1" ]; then + if ! mkdir -p "$RECEIPT_HOME"; then + err "unable to create receipt directory at $RECEIPT_HOME" + else + echo "$RECEIPT" > "$RECEIPT_HOME/$APP_NAME-receipt.json" + # shellcheck disable=SC2320 + local _retval=$? + fi + else + local _retval=0 + fi + + return "$_retval" +} + +# Replaces $HOME with the variable name for display to the user, +# only if $HOME is defined. +replace_home() { + local _str="$1" + + if [ -n "${HOME:-}" ]; then + echo "$_str" | sed "s,$HOME,\$HOME," + else + echo "$_str" + fi +} + +json_binary_aliases() { + local _arch="$1" + + case "$_arch" in + "aarch64-apple-darwin") + echo '{}' + ;; + "aarch64-pc-windows-gnu") + echo '{}' + ;; + "aarch64-unknown-linux-gnu") + echo '{}' + ;; + "aarch64-unknown-linux-musl-dynamic") + echo '{}' + ;; + "aarch64-unknown-linux-musl-static") + echo '{}' + ;; + "arm-unknown-linux-gnueabihf") + echo '{}' + ;; + "arm-unknown-linux-musl-dynamiceabihf") + echo '{}' + ;; + "arm-unknown-linux-musl-staticeabihf") + echo '{}' + ;; + "armv7-unknown-linux-gnueabihf") + echo '{}' + ;; + "armv7-unknown-linux-musl-dynamiceabihf") + echo '{}' + ;; + "armv7-unknown-linux-musl-staticeabihf") + echo '{}' + ;; + "i686-pc-windows-gnu") + echo '{}' + ;; + "i686-unknown-linux-gnu") + echo '{}' + ;; + "i686-unknown-linux-musl-dynamic") + echo '{}' + ;; + "i686-unknown-linux-musl-static") + echo '{}' + ;; + "powerpc64-unknown-linux-gnu") + echo '{}' + ;; + "powerpc64le-unknown-linux-gnu") + echo '{}' + ;; + "riscv64gc-unknown-linux-gnu") + echo '{}' + ;; + "s390x-unknown-linux-gnu") + echo '{}' + ;; + "x86_64-apple-darwin") + echo '{}' + ;; + "x86_64-pc-windows-gnu") + echo '{}' + ;; + "x86_64-unknown-linux-gnu") + echo '{}' + ;; + "x86_64-unknown-linux-musl-dynamic") + echo '{}' + ;; + "x86_64-unknown-linux-musl-static") + echo '{}' + ;; + *) + echo '{}' + ;; + esac +} + +aliases_for_binary() { + local _bin="$1" + local _arch="$2" + + case "$_arch" in + "aarch64-apple-darwin") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "aarch64-pc-windows-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "aarch64-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "aarch64-unknown-linux-musl-dynamic") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "aarch64-unknown-linux-musl-static") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "arm-unknown-linux-gnueabihf") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "arm-unknown-linux-musl-dynamiceabihf") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "arm-unknown-linux-musl-staticeabihf") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "armv7-unknown-linux-gnueabihf") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "armv7-unknown-linux-musl-dynamiceabihf") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "armv7-unknown-linux-musl-staticeabihf") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "i686-pc-windows-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "i686-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "i686-unknown-linux-musl-dynamic") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "i686-unknown-linux-musl-static") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "powerpc64-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "powerpc64le-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "riscv64gc-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "s390x-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "x86_64-apple-darwin") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "x86_64-pc-windows-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "x86_64-unknown-linux-gnu") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "x86_64-unknown-linux-musl-dynamic") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + "x86_64-unknown-linux-musl-static") + case "$_bin" in + *) + echo "" + ;; + esac + ;; + *) + echo "" + ;; + esac +} + +select_archive_for_arch() { + local _true_arch="$1" + local _archive + + # try each archive, checking runtime conditions like libc versions + # accepting the first one that matches, as it's the best match + case "$_true_arch" in + "aarch64-apple-darwin") + _archive="uv-aarch64-apple-darwin.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-x86_64-apple-darwin.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "aarch64-pc-windows-gnu") + _archive="uv-aarch64-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "aarch64-pc-windows-msvc") + _archive="uv-aarch64-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-x86_64-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-i686-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "aarch64-unknown-linux-gnu") + _archive="uv-aarch64-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "28"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-aarch64-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "aarch64-unknown-linux-musl-dynamic") + _archive="uv-aarch64-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "aarch64-unknown-linux-musl-static") + _archive="uv-aarch64-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "arm-unknown-linux-gnueabihf") + _archive="uv-arm-unknown-linux-musleabihf.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "arm-unknown-linux-musl-dynamiceabihf") + _archive="uv-arm-unknown-linux-musleabihf.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "arm-unknown-linux-musl-staticeabihf") + _archive="uv-arm-unknown-linux-musleabihf.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "armv7-unknown-linux-gnueabihf") + _archive="uv-armv7-unknown-linux-gnueabihf.tar.gz" + if ! check_glibc "2" "17"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-armv7-unknown-linux-musleabihf.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "armv7-unknown-linux-musl-dynamiceabihf") + _archive="uv-armv7-unknown-linux-musleabihf.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "armv7-unknown-linux-musl-staticeabihf") + _archive="uv-armv7-unknown-linux-musleabihf.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "i686-pc-windows-gnu") + _archive="uv-i686-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "i686-pc-windows-msvc") + _archive="uv-i686-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "i686-unknown-linux-gnu") + _archive="uv-i686-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "17"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-i686-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "i686-unknown-linux-musl-dynamic") + _archive="uv-i686-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "i686-unknown-linux-musl-static") + _archive="uv-i686-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "powerpc64-unknown-linux-gnu") + _archive="uv-powerpc64-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "17"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "powerpc64le-unknown-linux-gnu") + _archive="uv-powerpc64le-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "17"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "riscv64gc-unknown-linux-gnu") + _archive="uv-riscv64gc-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "31"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "s390x-unknown-linux-gnu") + _archive="uv-s390x-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "17"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "x86_64-apple-darwin") + _archive="uv-x86_64-apple-darwin.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "x86_64-pc-windows-gnu") + _archive="uv-x86_64-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "x86_64-pc-windows-msvc") + _archive="uv-x86_64-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-i686-pc-windows-msvc.zip" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "x86_64-unknown-linux-gnu") + _archive="uv-x86_64-unknown-linux-gnu.tar.gz" + if ! check_glibc "2" "17"; then + _archive="" + fi + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + _archive="uv-x86_64-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "x86_64-unknown-linux-musl-dynamic") + _archive="uv-x86_64-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + "x86_64-unknown-linux-musl-static") + _archive="uv-x86_64-unknown-linux-musl.tar.gz" + if [ -n "$_archive" ]; then + echo "$_archive" + return 0 + fi + ;; + *) + err "there isn't a download for your platform $_true_arch" + ;; + esac + err "no compatible downloads were found for your platform $_true_arch" +} + +check_glibc() { + local _min_glibc_major="$1" + local _min_glibc_series="$2" + + # Parsing version out from line 1 like: + # ldd (Ubuntu GLIBC 2.35-0ubuntu3.1) 2.35 + _local_glibc="$(ldd --version | awk -F' ' '{ if (FNR<=1) print $NF }')" + + if [ "$(echo "${_local_glibc}" | awk -F. '{ print $1 }')" = "$_min_glibc_major" ] && [ "$(echo "${_local_glibc}" | awk -F. '{ print $2 }')" -ge "$_min_glibc_series" ]; then + return 0 + else + say "System glibc version (\`${_local_glibc}') is too old; checking alternatives" >&2 + return 1 + fi +} + +# See discussion of late-bound vs early-bound for why we use single-quotes with env vars +# shellcheck disable=SC2016 +install() { + # This code needs to both compute certain paths for itself to write to, and + # also write them to shell/rc files so that they can look them up to e.g. + # add them to PATH. This requires an active distinction between paths + # and expressions that can compute them. + # + # The distinction lies in when we want env-vars to be evaluated. For instance + # if we determine that we want to install to $HOME/.myapp, which do we add + # to e.g. $HOME/.profile: + # + # * early-bound: export PATH="/home/myuser/.myapp:$PATH" + # * late-bound: export PATH="$HOME/.myapp:$PATH" + # + # In this case most people would prefer the late-bound version, but in other + # cases the early-bound version might be a better idea. In particular when using + # other env-vars than $HOME, they are more likely to be only set temporarily + # for the duration of this install script, so it's more advisable to erase their + # existence with early-bounding. + # + # This distinction is handled by "double-quotes" (early) vs 'single-quotes' (late). + # + # However if we detect that "$SOME_VAR/..." is a subdir of $HOME, we try to rewrite + # it to be '$HOME/...' to get the best of both worlds. + # + # This script has a few different variants, the most complex one being the + # CARGO_HOME version which attempts to install things to Cargo's bin dir, + # potentially setting up a minimal version if the user hasn't ever installed Cargo. + # + # In this case we need to: + # + # * Install to $HOME/.cargo/bin/ + # * Create a shell script at $HOME/.cargo/env that: + # * Checks if $HOME/.cargo/bin/ is on PATH + # * and if not prepends it to PATH + # * Edits $INFERRED_HOME/.profile to run $HOME/.cargo/env (if the line doesn't exist) + # + # To do this we need these 4 values: + + # The actual path we're going to install to + local _install_dir + # The directory C dynamic/static libraries install to + local _lib_install_dir + # The install prefix we write to the receipt. + # For organized install methods like CargoHome, which have + # subdirectories, this is the root without `/bin`. For other + # methods, this is the same as `_install_dir`. + local _receipt_install_dir + # Path to the an shell script that adds install_dir to PATH + local _env_script_path + # Potentially-late-bound version of install_dir to write env_script + local _install_dir_expr + # Potentially-late-bound version of env_script_path to write to rcfiles like $HOME/.profile + local _env_script_path_expr + # Forces the install to occur at this path, not the default + local _force_install_dir + # Which install layout to use - "flat" or "hierarchical" + local _install_layout="unspecified" + # A list of binaries which are shadowed in the PATH + local _shadowed_bins="" + + # Check the newer app-specific variable before falling back + # to the older generic one + if [ -n "${UV_INSTALL_DIR:-}" ]; then + _force_install_dir="$UV_INSTALL_DIR" + _install_layout="flat" + elif [ -n "${CARGO_DIST_FORCE_INSTALL_DIR:-}" ]; then + _force_install_dir="$CARGO_DIST_FORCE_INSTALL_DIR" + _install_layout="flat" + elif [ -n "$UNMANAGED_INSTALL" ]; then + _force_install_dir="$UNMANAGED_INSTALL" + _install_layout="flat" + fi + + # Check if the install layout should be changed from `flat` to `cargo-home` + # for backwards compatible updates of applications that switched layouts. + if [ -n "${_force_install_dir:-}" ]; then + if [ "$_install_layout" = "flat" ]; then + # If the install directory is targeting the Cargo home directory, then + # we assume this application was previously installed that layout + if [ "$_force_install_dir" = "${CARGO_HOME:-${INFERRED_HOME:-}/.cargo}" ]; then + _install_layout="cargo-home" + fi + fi + fi + + # Before actually consulting the configured install strategy, see + # if we're overriding it. + if [ -n "${_force_install_dir:-}" ]; then + case "$_install_layout" in + "hierarchical") + _install_dir="$_force_install_dir/bin" + _lib_install_dir="$_force_install_dir/lib" + _receipt_install_dir="$_force_install_dir" + _env_script_path="$_force_install_dir/env" + _install_dir_expr="$(replace_home "$_force_install_dir/bin")" + _env_script_path_expr="$(replace_home "$_force_install_dir/env")" + ;; + "cargo-home") + _install_dir="$_force_install_dir/bin" + _lib_install_dir="$_force_install_dir/bin" + _receipt_install_dir="$_force_install_dir" + _env_script_path="$_force_install_dir/env" + _install_dir_expr="$(replace_home "$_force_install_dir/bin")" + _env_script_path_expr="$(replace_home "$_force_install_dir/env")" + ;; + "flat") + _install_dir="$_force_install_dir" + _lib_install_dir="$_force_install_dir" + _receipt_install_dir="$_install_dir" + _env_script_path="$_force_install_dir/env" + _install_dir_expr="$(replace_home "$_force_install_dir")" + _env_script_path_expr="$(replace_home "$_force_install_dir/env")" + ;; + *) + err "Unrecognized install layout: $_install_layout" + ;; + esac + fi + if [ -z "${_install_dir:-}" ]; then + _install_layout="flat" + # Install to $XDG_BIN_HOME + if [ -n "${XDG_BIN_HOME:-}" ]; then + _install_dir="$XDG_BIN_HOME" + _lib_install_dir="$_install_dir" + _receipt_install_dir="$_install_dir" + _env_script_path="$XDG_BIN_HOME/env" + _install_dir_expr="$(replace_home "$_install_dir")" + _env_script_path_expr="$(replace_home "$_env_script_path")" + fi + fi + if [ -z "${_install_dir:-}" ]; then + _install_layout="flat" + # Install to $XDG_DATA_HOME/../bin + if [ -n "${XDG_DATA_HOME:-}" ]; then + _install_dir="$XDG_DATA_HOME/../bin" + _lib_install_dir="$_install_dir" + _receipt_install_dir="$_install_dir" + _env_script_path="$XDG_DATA_HOME/../bin/env" + _install_dir_expr="$(replace_home "$_install_dir")" + _env_script_path_expr="$(replace_home "$_env_script_path")" + fi + fi + if [ -z "${_install_dir:-}" ]; then + _install_layout="flat" + # Install to $HOME/.local/bin + if [ -n "${INFERRED_HOME:-}" ]; then + _install_dir="$INFERRED_HOME/.local/bin" + _lib_install_dir="$INFERRED_HOME/.local/bin" + _receipt_install_dir="$_install_dir" + _env_script_path="$INFERRED_HOME/.local/bin/env" + _install_dir_expr="$INFERRED_HOME_EXPRESSION/.local/bin" + _env_script_path_expr="$INFERRED_HOME_EXPRESSION/.local/bin/env" + fi + fi + + if [ -z "$_install_dir_expr" ]; then + err "could not find a valid path to install to!" + fi + + # Identical to the sh version, just with a .fish file extension + # We place it down here to wait until it's been assigned in every + # path. + _fish_env_script_path="${_env_script_path}.fish" + _fish_env_script_path_expr="${_env_script_path_expr}.fish" + + # Replace the temporary cargo home with the calculated one + RECEIPT=$(echo "$RECEIPT" | sed "s,AXO_INSTALL_PREFIX,$_receipt_install_dir,") + # Also replace the aliases with the arch-specific one + RECEIPT=$(echo "$RECEIPT" | sed "s'\"binary_aliases\":{}'\"binary_aliases\":$(json_binary_aliases "$_arch")'") + # And replace the install layout + RECEIPT=$(echo "$RECEIPT" | sed "s'\"install_layout\":\"unspecified\"'\"install_layout\":\"$_install_layout\"'") + if [ "$NO_MODIFY_PATH" = "1" ]; then + RECEIPT=$(echo "$RECEIPT" | sed "s'\"modify_path\":true'\"modify_path\":false'") + fi + + say "installing to $_install_dir" + ensure mkdir -p "$_install_dir" + ensure mkdir -p "$_lib_install_dir" + + # copy all the binaries to the install dir + local _src_dir="$1" + local _bins="$2" + local _libs="$3" + local _staticlibs="$4" + local _arch="$5" + for _bin_name in $_bins; do + local _bin="$_src_dir/$_bin_name" + ensure mv "$_bin" "$_install_dir" + # unzip seems to need this chmod + ensure chmod +x "$_install_dir/$_bin_name" + for _dest in $(aliases_for_binary "$_bin_name" "$_arch"); do + ln -sf "$_install_dir/$_bin_name" "$_install_dir/$_dest" + done + say " $_bin_name" + done + # Like the above, but no aliases + for _lib_name in $_libs; do + local _lib="$_src_dir/$_lib_name" + ensure mv "$_lib" "$_lib_install_dir" + # unzip seems to need this chmod + ensure chmod +x "$_lib_install_dir/$_lib_name" + say " $_lib_name" + done + for _lib_name in $_staticlibs; do + local _lib="$_src_dir/$_lib_name" + ensure mv "$_lib" "$_lib_install_dir" + # unzip seems to need this chmod + ensure chmod +x "$_lib_install_dir/$_lib_name" + say " $_lib_name" + done + + say "everything's installed!" + + # Avoid modifying the users PATH if they are managing their PATH manually + case :$PATH: + in *:$_install_dir:*) NO_MODIFY_PATH=1 ;; + *) ;; + esac + + if [ "0" = "$NO_MODIFY_PATH" ]; then + add_install_dir_to_ci_path "$_install_dir" + add_install_dir_to_path "$_install_dir_expr" "$_env_script_path" "$_env_script_path_expr" ".profile" "sh" + exit1=$? + shotgun_install_dir_to_path "$_install_dir_expr" "$_env_script_path" "$_env_script_path_expr" ".profile .bashrc .bash_profile .bash_login" "sh" + exit2=$? + add_install_dir_to_path "$_install_dir_expr" "$_env_script_path" "$_env_script_path_expr" ".zshrc .zshenv" "sh" + exit3=$? + # This path may not exist by default + ensure mkdir -p "$INFERRED_HOME/.config/fish/conf.d" + exit4=$? + add_install_dir_to_path "$_install_dir_expr" "$_fish_env_script_path" "$_fish_env_script_path_expr" ".config/fish/conf.d/$APP_NAME.env.fish" "fish" + exit5=$? + + if [ "${exit1:-0}" = 1 ] || [ "${exit2:-0}" = 1 ] || [ "${exit3:-0}" = 1 ] || [ "${exit4:-0}" = 1 ] || [ "${exit5:-0}" = 1 ]; then + say "" + say "To add $_install_dir_expr to your PATH, either restart your shell or run:" + say "" + say " source $_env_script_path_expr (sh, bash, zsh)" + say " source $_fish_env_script_path_expr (fish)" + fi + fi + + _shadowed_bins="$(check_for_shadowed_bins "$_install_dir" "$_bins")" + if [ -n "$_shadowed_bins" ]; then + warn "The following commands are shadowed by other commands in your PATH:$_shadowed_bins" + fi +} + +check_for_shadowed_bins() { + local _install_dir="$1" + local _bins="$2" + local _shadow + + for _bin_name in $_bins; do + _shadow="$(command -v "$_bin_name")" + if [ -n "$_shadow" ] && [ "$_shadow" != "$_install_dir/$_bin_name" ]; then + _shadowed_bins="$_shadowed_bins $_bin_name" + fi + done + + echo "$_shadowed_bins" +} + +print_home_for_script() { + local script="$1" + + local _home + case "$script" in + # zsh has a special ZDOTDIR directory, which if set + # should be considered instead of $HOME + .zsh*) + if [ -n "${ZDOTDIR:-}" ]; then + _home="$ZDOTDIR" + else + _home="$INFERRED_HOME" + fi + ;; + *) + _home="$INFERRED_HOME" + ;; + esac + + echo "$_home" +} + +add_install_dir_to_ci_path() { + # Attempt to do CI-specific rituals to get the install-dir on PATH faster + local _install_dir="$1" + + # If GITHUB_PATH is present, then write install_dir to the file it refs. + # After each GitHub Action, the contents will be added to PATH. + # So if you put a curl | sh for this script in its own "run" step, + # the next step will have this dir on PATH. + # + # Note that GITHUB_PATH will not resolve any variables, so we in fact + # want to write install_dir and not install_dir_expr + if [ -n "${GITHUB_PATH:-}" ]; then + ensure echo "$_install_dir" >> "$GITHUB_PATH" + fi +} + +add_install_dir_to_path() { + # Edit rcfiles ($HOME/.profile) to add install_dir to $PATH + # + # We do this slightly indirectly by creating an "env" shell script which checks if install_dir + # is on $PATH already, and prepends it if not. The actual line we then add to rcfiles + # is to just source that script. This allows us to blast it into lots of different rcfiles and + # have it run multiple times without causing problems. It's also specifically compatible + # with the system rustup uses, so that we don't conflict with it. + local _install_dir_expr="$1" + local _env_script_path="$2" + local _env_script_path_expr="$3" + local _rcfiles="$4" + local _shell="$5" + + if [ -n "${INFERRED_HOME:-}" ]; then + local _target + local _home + + # Find the first file in the array that exists and choose + # that as our target to write to + for _rcfile_relative in $_rcfiles; do + _home="$(print_home_for_script "$_rcfile_relative")" + local _rcfile="$_home/$_rcfile_relative" + + if [ -f "$_rcfile" ]; then + _target="$_rcfile" + break + fi + done + + # If we didn't find anything, pick the first entry in the + # list as the default to create and write to + if [ -z "${_target:-}" ]; then + local _rcfile_relative + _rcfile_relative="$(echo "$_rcfiles" | awk '{ print $1 }')" + _home="$(print_home_for_script "$_rcfile_relative")" + _target="$_home/$_rcfile_relative" + fi + + # `source x` is an alias for `. x`, and the latter is more portable/actually-posix. + # This apparently comes up a lot on freebsd. It's easy enough to always add + # the more robust line to rcfiles, but when telling the user to apply the change + # to their current shell ". x" is pretty easy to misread/miscopy, so we use the + # prettier "source x" line there. Hopefully people with Weird Shells are aware + # this is a thing and know to tweak it (or just restart their shell). + local _robust_line=". \"$_env_script_path_expr\"" + local _pretty_line="source \"$_env_script_path_expr\"" + + # Add the env script if it doesn't already exist + if [ ! -f "$_env_script_path" ]; then + say_verbose "creating $_env_script_path" + if [ "$_shell" = "sh" ]; then + write_env_script_sh "$_install_dir_expr" "$_env_script_path" + else + write_env_script_fish "$_install_dir_expr" "$_env_script_path" + fi + else + say_verbose "$_env_script_path already exists" + fi + + # Check if the line is already in the rcfile + # grep: 0 if matched, 1 if no match, and 2 if an error occurred + # + # Ideally we could use quiet grep (-q), but that makes "match" and "error" + # have the same behaviour, when we want "no match" and "error" to be the same + # (on error we want to create the file, which >> conveniently does) + # + # We search for both kinds of line here just to do the right thing in more cases. + if ! grep -F "$_robust_line" "$_target" > /dev/null 2>/dev/null && \ + ! grep -F "$_pretty_line" "$_target" > /dev/null 2>/dev/null + then + # If the script now exists, add the line to source it to the rcfile + # (This will also create the rcfile if it doesn't exist) + if [ -f "$_env_script_path" ]; then + local _line + # Fish has deprecated `.` as an alias for `source` and + # it will be removed in a later version. + # https://fishshell.com/docs/current/cmds/source.html + # By contrast, `.` is the traditional syntax in sh and + # `source` isn't always supported in all circumstances. + if [ "$_shell" = "fish" ]; then + _line="$_pretty_line" + else + _line="$_robust_line" + fi + say_verbose "adding $_line to $_target" + # prepend an extra newline in case the user's file is missing a trailing one + ensure echo "" >> "$_target" + ensure echo "$_line" >> "$_target" + return 1 + fi + else + say_verbose "$_install_dir already on PATH" + fi + fi +} + +shotgun_install_dir_to_path() { + # Edit rcfiles ($HOME/.profile) to add install_dir to $PATH + # (Shotgun edition - write to all provided files that exist rather than just the first) + local _install_dir_expr="$1" + local _env_script_path="$2" + local _env_script_path_expr="$3" + local _rcfiles="$4" + local _shell="$5" + + if [ -n "${INFERRED_HOME:-}" ]; then + local _found=false + local _home + + for _rcfile_relative in $_rcfiles; do + _home="$(print_home_for_script "$_rcfile_relative")" + local _rcfile_abs="$_home/$_rcfile_relative" + + if [ -f "$_rcfile_abs" ]; then + _found=true + add_install_dir_to_path "$_install_dir_expr" "$_env_script_path" "$_env_script_path_expr" "$_rcfile_relative" "$_shell" + fi + done + + # Fall through to previous "create + write to first file in list" behavior + if [ "$_found" = false ]; then + add_install_dir_to_path "$_install_dir_expr" "$_env_script_path" "$_env_script_path_expr" "$_rcfiles" "$_shell" + fi + fi +} + +write_env_script_sh() { + # write this env script to the given path (this cat/EOF stuff is a "heredoc" string) + local _install_dir_expr="$1" + local _env_script_path="$2" + ensure cat < "$_env_script_path" +#!/bin/sh +# add binaries to PATH if they aren't added yet +# affix colons on either side of \$PATH to simplify matching +case ":\${PATH}:" in + *:"$_install_dir_expr":*) + ;; + *) + # Prepending path in case a system-installed binary needs to be overridden + export PATH="$_install_dir_expr:\$PATH" + ;; +esac +EOF +} + +write_env_script_fish() { + # write this env script to the given path (this cat/EOF stuff is a "heredoc" string) + local _install_dir_expr="$1" + local _env_script_path="$2" + ensure cat < "$_env_script_path" +if not contains "$_install_dir_expr" \$PATH + # Prepending path in case a system-installed binary needs to be overridden + set -x PATH "$_install_dir_expr" \$PATH +end +EOF +} + +get_current_exe() { + # Returns the executable used for system architecture detection + # This is only run on Linux + local _current_exe + if test -L /proc/self/exe ; then + _current_exe=/proc/self/exe + else + warn "Unable to find /proc/self/exe. System architecture detection might be inaccurate." + if test -n "$SHELL" ; then + _current_exe=$SHELL + else + need_cmd /bin/sh + _current_exe=/bin/sh + fi + warn "Falling back to $_current_exe." + fi + echo "$_current_exe" +} + +get_bitness() { + need_cmd head + # Architecture detection without dependencies beyond coreutils. + # ELF files start out "\x7fELF", and the following byte is + # 0x01 for 32-bit and + # 0x02 for 64-bit. + # The printf builtin on some shells like dash only supports octal + # escape sequences, so we use those. + local _current_exe=$1 + local _current_exe_head + _current_exe_head=$(head -c 5 "$_current_exe") + if [ "$_current_exe_head" = "$(printf '\177ELF\001')" ]; then + echo 32 + elif [ "$_current_exe_head" = "$(printf '\177ELF\002')" ]; then + echo 64 + else + err "unknown platform bitness" + fi +} + +is_host_amd64_elf() { + local _current_exe=$1 + + need_cmd head + need_cmd tail + # ELF e_machine detection without dependencies beyond coreutils. + # Two-byte field at offset 0x12 indicates the CPU, + # but we're interested in it being 0x3E to indicate amd64, or not that. + local _current_exe_machine + _current_exe_machine=$(head -c 19 "$_current_exe" | tail -c 1) + [ "$_current_exe_machine" = "$(printf '\076')" ] +} + +get_endianness() { + local _current_exe=$1 + local cputype=$2 + local suffix_eb=$3 + local suffix_el=$4 + + # detect endianness without od/hexdump, like get_bitness() does. + need_cmd head + need_cmd tail + + local _current_exe_endianness + _current_exe_endianness="$(head -c 6 "$_current_exe" | tail -c 1)" + if [ "$_current_exe_endianness" = "$(printf '\001')" ]; then + echo "${cputype}${suffix_el}" + elif [ "$_current_exe_endianness" = "$(printf '\002')" ]; then + echo "${cputype}${suffix_eb}" + else + err "unknown platform endianness" + fi +} + +# Detect the Linux/LoongArch UAPI flavor, with all errors being non-fatal. +# Returns 0 or 234 in case of successful detection, 1 otherwise (/tmp being +# noexec, or other causes). +check_loongarch_uapi() { + need_cmd base64 + + local _tmp + if ! _tmp="$(ensure mktemp)"; then + return 1 + fi + + # Minimal Linux/LoongArch UAPI detection, exiting with 0 in case of + # upstream ("new world") UAPI, and 234 (-EINVAL truncated) in case of + # old-world (as deployed on several early commercial Linux distributions + # for LoongArch). + # + # See https://gist.github.com/xen0n/5ee04aaa6cecc5c7794b9a0c3b65fc7f for + # source to this helper binary. + ignore base64 -d > "$_tmp" <&1 | grep -q 'musl'; then + _clibtype="musl-dynamic" + else + # Assume all other linuxes are glibc (even if wrong, static libc fallback will apply) + _clibtype="gnu" + fi + fi + + if [ "$_ostype" = Darwin ]; then + # Darwin `uname -m` can lie due to Rosetta shenanigans. If you manage to + # invoke a native shell binary and then a native uname binary, you can + # get the real answer, but that's hard to ensure, so instead we use + # `sysctl` (which doesn't lie) to check for the actual architecture. + if [ "$_cputype" = i386 ]; then + # Handling i386 compatibility mode in older macOS versions (<10.15) + # running on x86_64-based Macs. + # Starting from 10.15, macOS explicitly bans all i386 binaries from running. + # See: + + # Avoid `sysctl: unknown oid` stderr output and/or non-zero exit code. + if sysctl hw.optional.x86_64 2> /dev/null || true | grep -q ': 1'; then + _cputype=x86_64 + fi + elif [ "$_cputype" = x86_64 ]; then + # Handling x86-64 compatibility mode (a.k.a. Rosetta 2) + # in newer macOS versions (>=11) running on arm64-based Macs. + # Rosetta 2 is built exclusively for x86-64 and cannot run i386 binaries. + + # Avoid `sysctl: unknown oid` stderr output and/or non-zero exit code. + if sysctl hw.optional.arm64 2> /dev/null || true | grep -q ': 1'; then + _cputype=arm64 + fi + fi + fi + + if [ "$_ostype" = SunOS ]; then + # Both Solaris and illumos presently announce as "SunOS" in "uname -s" + # so use "uname -o" to disambiguate. We use the full path to the + # system uname in case the user has coreutils uname first in PATH, + # which has historically sometimes printed the wrong value here. + if [ "$(/usr/bin/uname -o)" = illumos ]; then + _ostype=illumos + fi + + # illumos systems have multi-arch userlands, and "uname -m" reports the + # machine hardware name; e.g., "i86pc" on both 32- and 64-bit x86 + # systems. Check for the native (widest) instruction set on the + # running kernel: + if [ "$_cputype" = i86pc ]; then + _cputype="$(isainfo -n)" + fi + fi + + local _current_exe + case "$_ostype" in + + Android) + _ostype=linux-android + ;; + + Linux) + _current_exe=$(get_current_exe) + _ostype=unknown-linux-$_clibtype + _bitness=$(get_bitness "$_current_exe") + ;; + + FreeBSD) + _ostype=unknown-freebsd + ;; + + NetBSD) + _ostype=unknown-netbsd + ;; + + DragonFly) + _ostype=unknown-dragonfly + ;; + + Darwin) + _ostype=apple-darwin + ;; + + illumos) + _ostype=unknown-illumos + ;; + + MINGW* | MSYS* | CYGWIN* | Windows_NT) + _ostype=pc-windows-gnu + ;; + + *) + err "unrecognized OS type: $_ostype" + ;; + + esac + + case "$_cputype" in + + i386 | i486 | i686 | i786 | x86) + _cputype=i686 + ;; + + xscale | arm) + _cputype=arm + if [ "$_ostype" = "linux-android" ]; then + _ostype=linux-androideabi + fi + ;; + + armv6l) + _cputype=arm + if [ "$_ostype" = "linux-android" ]; then + _ostype=linux-androideabi + else + _ostype="${_ostype}eabihf" + fi + ;; + + armv7l | armv8l) + _cputype=armv7 + if [ "$_ostype" = "linux-android" ]; then + _ostype=linux-androideabi + else + _ostype="${_ostype}eabihf" + fi + ;; + + aarch64 | arm64) + _cputype=aarch64 + ;; + + x86_64 | x86-64 | x64 | amd64) + _cputype=x86_64 + ;; + + mips) + _cputype=$(get_endianness "$_current_exe" mips '' el) + ;; + + mips64) + if [ "$_bitness" -eq 64 ]; then + # only n64 ABI is supported for now + _ostype="${_ostype}abi64" + _cputype=$(get_endianness "$_current_exe" mips64 '' el) + fi + ;; + + ppc) + _cputype=powerpc + ;; + + ppc64) + _cputype=powerpc64 + ;; + + ppc64le) + _cputype=powerpc64le + ;; + + s390x) + _cputype=s390x + ;; + riscv64) + _cputype=riscv64gc + ;; + loongarch64) + _cputype=loongarch64 + ensure_loongarch_uapi + ;; + *) + err "unknown CPU type: $_cputype" + + esac + + # Detect 64-bit linux with 32-bit userland + if [ "${_ostype}" = unknown-linux-gnu ] && [ "${_bitness}" -eq 32 ]; then + case $_cputype in + x86_64) + # 32-bit executable for amd64 = x32 + if is_host_amd64_elf "$_current_exe"; then { + err "x32 linux unsupported" + }; else + _cputype=i686 + fi + ;; + mips64) + _cputype=$(get_endianness "$_current_exe" mips '' el) + ;; + powerpc64) + _cputype=powerpc + ;; + aarch64) + _cputype=armv7 + if [ "$_ostype" = "linux-android" ]; then + _ostype=linux-androideabi + else + _ostype="${_ostype}eabihf" + fi + ;; + riscv64gc) + err "riscv64 with 32-bit userland unsupported" + ;; + esac + fi + + # Detect armv7 but without the CPU features Rust needs in that build, + # and fall back to arm. + if [ "$_ostype" = "unknown-linux-gnueabihf" ] && [ "$_cputype" = armv7 ]; then + if ! (ensure grep '^Features' /proc/cpuinfo | grep -E -q 'neon|simd') ; then + # Either `/proc/cpuinfo` is malformed or unavailable, or + # at least one processor does not have NEON (which is asimd on armv8+). + _cputype=arm + fi + fi + + _arch="${_cputype}-${_ostype}" + + RETVAL="$_arch" +} + +say() { + if [ "0" = "$PRINT_QUIET" ]; then + echo "$1" + fi +} + +say_verbose() { + if [ "1" = "$PRINT_VERBOSE" ]; then + echo "$1" + fi +} + +warn() { + if [ "0" = "$PRINT_QUIET" ]; then + local red + local reset + red=$(tput setaf 1 2>/dev/null || echo '') + reset=$(tput sgr0 2>/dev/null || echo '') + say "${red}WARN${reset}: $1" >&2 + fi +} + +err() { + if [ "0" = "$PRINT_QUIET" ]; then + local red + local reset + red=$(tput setaf 1 2>/dev/null || echo '') + reset=$(tput sgr0 2>/dev/null || echo '') + say "${red}ERROR${reset}: $1" >&2 + fi + exit 1 +} + +need_cmd() { + if ! check_cmd "$1" + then err "need '$1' (command not found)" + fi +} + +check_cmd() { + command -v "$1" > /dev/null 2>&1 + return $? +} + +assert_nz() { + if [ -z "$1" ]; then err "assert_nz $2"; fi +} + +# Run a command that should never fail. If the command fails execution +# will immediately terminate with an error showing the failing +# command. +ensure() { + if ! "$@"; then err "command failed: $*"; fi +} + +# This is just for indicating that commands' results are being +# intentionally ignored. Usually, because it's being executed +# as part of error handling. +ignore() { + "$@" +} + +# This wraps curl or wget. Try curl first, if not installed, +# use wget instead. +downloader() { + # Check if we have a broken snap curl + # https://github.com/boukendesho/curl-snap/issues/1 + _snap_curl=0 + if command -v curl > /dev/null 2>&1; then + _curl_path=$(command -v curl) + if echo "$_curl_path" | grep "/snap/" > /dev/null 2>&1; then + _snap_curl=1 + fi + fi + + # Check if we have a working (non-snap) curl + if check_cmd curl && [ "$_snap_curl" = "0" ] + then _dld=curl + # Try wget for both no curl and the broken snap curl + elif check_cmd wget + then _dld=wget + # If we can't fall back from broken snap curl to wget, report the broken snap curl + elif [ "$_snap_curl" = "1" ] + then + say "curl installed with snap cannot be used to install $APP_NAME" + say "due to missing permissions. Please uninstall it and" + say "reinstall curl with a different package manager (e.g., apt)." + say "See https://github.com/boukendesho/curl-snap/issues/1" + exit 1 + else _dld='curl or wget' # to be used in error message of need_cmd + fi + + if [ "$1" = --check ] + then need_cmd "$_dld" + elif [ "$_dld" = curl ]; then + if [ -n "${AUTH_TOKEN:-}" ]; then + curl -sSfL --header "Authorization: Bearer ${AUTH_TOKEN}" "$1" -o "$2" + else + curl -sSfL "$1" -o "$2" + fi + elif [ "$_dld" = wget ]; then + if [ -n "${AUTH_TOKEN:-}" ]; then + wget --header "Authorization: Bearer ${AUTH_TOKEN}" "$1" -O "$2" + else + wget "$1" -O "$2" + fi + else err "Unknown downloader" # should not reach here + fi +} + +verify_checksum() { + local _file="$1" + local _checksum_style="$2" + local _checksum_value="$3" + local _calculated_checksum + + if [ -z "$_checksum_value" ]; then + return 0 + fi + case "$_checksum_style" in + sha256) + if ! check_cmd sha256sum; then + say "skipping sha256 checksum verification (it requires the 'sha256sum' command)" + return 0 + fi + _calculated_checksum="$(sha256sum -b "$_file" | awk '{printf $1}')" + ;; + sha512) + if ! check_cmd sha512sum; then + say "skipping sha512 checksum verification (it requires the 'sha512sum' command)" + return 0 + fi + _calculated_checksum="$(sha512sum -b "$_file" | awk '{printf $1}')" + ;; + sha3-256) + if ! check_cmd openssl; then + say "skipping sha3-256 checksum verification (it requires the 'openssl' command)" + return 0 + fi + _calculated_checksum="$(openssl dgst -sha3-256 "$_file" | awk '{printf $NF}')" + ;; + sha3-512) + if ! check_cmd openssl; then + say "skipping sha3-512 checksum verification (it requires the 'openssl' command)" + return 0 + fi + _calculated_checksum="$(openssl dgst -sha3-512 "$_file" | awk '{printf $NF}')" + ;; + blake2s) + if ! check_cmd b2sum; then + say "skipping blake2s checksum verification (it requires the 'b2sum' command)" + return 0 + fi + # Test if we have official b2sum with blake2s support + local _well_known_blake2s_checksum="93314a61f470985a40f8da62df10ba0546dc5216e1d45847bf1dbaa42a0e97af" + local _test_blake2s + _test_blake2s="$(printf "can do blake2s" | b2sum -a blake2s | awk '{printf $1}')" || _test_blake2s="" + + if [ "X$_test_blake2s" = "X$_well_known_blake2s_checksum" ]; then + _calculated_checksum="$(b2sum -a blake2s "$_file" | awk '{printf $1}')" || _calculated_checksum="" + else + say "skipping blake2s checksum verification (installed b2sum doesn't support blake2s)" + return 0 + fi + ;; + blake2b) + if ! check_cmd b2sum; then + say "skipping blake2b checksum verification (it requires the 'b2sum' command)" + return 0 + fi + _calculated_checksum="$(b2sum "$_file" | awk '{printf $1}')" + ;; + false) + ;; + *) + say "skipping unknown checksum style: $_checksum_style" + return 0 + ;; + esac + + if [ "$_calculated_checksum" != "$_checksum_value" ]; then + err "checksum mismatch + want: $_checksum_value + got: $_calculated_checksum" + fi +} + +download_binary_and_run_installer "$@" || exit 1 diff --git a/launcher.py b/launcher.py deleted file mode 100644 index f6a90807..00000000 --- a/launcher.py +++ /dev/null @@ -1,430 +0,0 @@ -import subprocess -import argparse -import shutil -import time -import sys -import os -from loguru import logger -from dotenv import load_dotenv; load_dotenv() -from astune.utils.smart_daemon import LaunchCommandWhenAbsent -from astune.utils.cleaner import _fast_kill_by_keyword_bash - - -def parse_args(): - parser = argparse.ArgumentParser(description='BA Launcher') - - parser.add_argument('--backbone', - type=str, - default="trinity", - required=False, - help='verl or trinity or debug' - ) - parser.add_argument('--conf', - type=str, - default="", - required=False, - help='Path to configuration file' - ) - parser.add_argument('--db', - type=str, - default="", - required=False, - help='Path to configuration file' - ) - parser.add_argument('--with-exp-maker', - action='store_true', - default=False, - help='Launch exp maker' - ) - parser.add_argument('--with-ray', - action='store_true', - default=False, - help='Launch ray' - ) - parser.add_argument('--with-appworld', - action='store_true', - default=False, - help='Launch appworld' - ) - parser.add_argument('--with-webshop', - action='store_true', - default=False, - help='Launch webshop' - ) - parser.add_argument('--with-bfcl', - action='store_true', - default=False, - help='Launch bfcl' - ) - parser.add_argument('--with-logview', - action='store_true', - default=False, - help='Launch logview' - ) - parser.add_argument('--with-crafters', - action='store_true', - default=False, - help='Launch Crafters Env Simulation' - ) - parser.add_argument('--reboot', - action='store_true', - default=False, - help='reboot flag' - ) - parser.add_argument('--kill', - type=str, - default="", - required=False, - help='list of keywords for killing processes' - ) - - return parser.parse_args() - -def check_debugpy_version(): - try: - import debugpy - except ImportError: - raise RuntimeError( - "Module 'debugpy>=1.8.0' cannot be loaded. " - "Ray Debugpy Debugger will not work without 'debugpy>=1.8.0' installed. " - "Install this module using 'pip install debugpy>=1.8.0'" - ) - version = getattr(debugpy, '__version__', '0.0.0') - from packaging import version as packaging_version - if packaging_version.parse(version) < packaging_version.parse('1.8.0'): - raise RuntimeError( - f"debugpy version {version} is too old. " - "Ray Debugpy Debugger requires 'debugpy>=1.8.0'. " - "Upgrade using 'pip install debugpy>=1.8.0'" - ) - print(f"✓ debugpy version {version} meets requirement (>=1.8.0)") - - -def pty_launch(service_name: str, success_std_string="Starting server on"): - service_path = os.environ.get(f'{service_name.upper()}_PATH') - service_script = os.environ.get(f'{service_name.upper()}_SCRIPT') - if service_path is None or service_script is None: - raise ValueError(f"Environment variables for {service_name} not properly set.") - companion = LaunchCommandWhenAbsent( - full_argument_list=[service_script], - dir=service_path, - tag="appworld_env_service", - use_pty=True - ) - companion.launch( - launch_wait_time=1800, - success_std_string=success_std_string, - ) - -def prepare_experiment_config(yaml_path, args): - """ - Prepare experiment configuration by reading YAML, setting up backup directories, - and copying necessary files for the experiment. - - Args: - yaml_path: Path to the YAML configuration file - args: Command line arguments - - Returns: - tuple: (yaml_backup_dst, exe_exp_base, exe_yaml_path, exp_name) - """ - assert yaml_path.endswith('.yaml'), "Configuration file must be a YAML file" - exp_base = os.path.dirname(yaml_path) - - if not os.path.exists(exp_base): - raise FileNotFoundError(f"Configuration file not found: {exp_base}") - - ## 0. read yaml (get astune.experiment_name) - import yaml - with open(yaml_path, 'r') as file: - config = yaml.safe_load(file) - exp_name = config.get('astune').get('experiment_name') - print('C1', exp_name) - if exp_name is None or exp_name == 'read_yaml_name': - if exp_name is not None: exp_name = exp_name.replace('|', '-') - exp_name = os.path.basename(yaml_path).replace('.yaml', '') - print('C2', yaml_path, exp_name) - else: - exp_name = exp_name.replace('|', '-') - - print('----------------------------------------') - backup_dir = os.path.join('launcher_record', exp_name, 'backup') - yaml_backup_dst = os.path.join('launcher_record', exp_name, 'yaml_backup.yaml') - exe_yaml_path = yaml_backup_dst - exe_exp_base = os.path.dirname(yaml_backup_dst) - print('Experiment Name:', exp_name) - print('Experiment Backup Dir:', backup_dir) - print('Experiment Yaml Dir:', yaml_backup_dst) - print('----------------------------------------') - time.sleep(2) - - ## 1. check exp_base/backup exist - if not os.path.exists(backup_dir): - os.makedirs(backup_dir) - else: - total_seconds = 5 - for i in range(total_seconds): - print(f"\rWarning: backup directory already exists, we will automatically ignore this after {total_seconds - i} seconds...", end="", flush=True) - time.sleep(1) - - ## 2. copy files to backup - BACK_TARGETS = os.environ.get('BACK_TARGETS', '').split(',') - BACK_TARGETS = [p for p in BACK_TARGETS if os.path.exists(p)] - - for backup_target in BACK_TARGETS: - print(f"Copying {backup_target} to {os.path.join(backup_dir, os.path.basename(backup_target))}") - shutil.copytree(backup_target, os.path.join(backup_dir, os.path.basename(backup_target)), dirs_exist_ok=True) - - ## 3. copy yaml to backup - yaml_backup_src = yaml_path - shutil.copyfile(yaml_backup_src, yaml_backup_dst) - - ## 4. edit new yaml - yaml_path = yaml_backup_dst - with open(yaml_path, 'r') as file: - config = yaml.safe_load(file) - config['astune']['experiment_name'] = exp_name - # remove extra config - if args.backbone != "verl": - config['defaults'].remove('ppo_trainer') - config['hydra']['searchpath'].remove('file://external/verl/verl/trainer/config') - with open(yaml_path, 'w') as file: - yaml.dump(config, file) - - return yaml_backup_dst, exe_exp_base, exe_yaml_path, exp_name, config - -def launch_logview(exp_name=None): - """ - Launch the log viewer service and open the web browser to view logs. - - Args: - exp_name: Optional experiment name. If not provided, "default_experiment" is used. - """ - companion = LaunchCommandWhenAbsent( - full_argument_list=[ - sys.executable, - '-m', - 'web_display.start_web', - ], - dir='./', - tag="logview" - ) - companion.launch(launch_wait_time=1800, success_std_string="Uvicorn running on", env_dict={}) - time.sleep(2.5) - try: - import webbrowser - from datetime import datetime - # Use default experiment name if not set - webbrowser.open("http://127.0.0.1:8181/") - except Exception as e: - print(f"Error opening web browser: {e}") - pass - -def start_ray_service(args, env): - """ - Start a Ray service with appropriate configuration. - - Args: - args: Command line arguments containing debug settings - """ - companion = LaunchCommandWhenAbsent( - full_argument_list=[ - f"source ./.venv/bin/activate && ray start --head --block" - ], - dir='./', - tag="ray_service", - use_pty=True - ) - companion.launch( - launch_wait_time=1800, - success_std_string="Ray runtime started", - env_dict=env, - ) - -import yaml - -def align_parameters(from_config_fp, to_config_fp, convertion_json_fg): - # read yaml files - with open(from_config_fp, 'r') as file: - from_config = yaml.safe_load(file) - with open(to_config_fp, 'r') as file: - to_config = yaml.safe_load(file) - # read convertion json - import json - with open(convertion_json_fg, 'r') as file: - convertion_json = json.load(file) - logger.success("----------------------------------------------------") - for from_key, to_key in convertion_json.items(): - # get value from from_config - keys = from_key.split('.') - value = from_config - for key in keys: - value = value.get(key, None) - if value is None: - break - if value is None: - logger.warning(f"[Warning]: Cannot find value for key: {from_key} in {from_config_fp}, skip aligning {to_key}") - continue - # set value to to_config - keys = to_key.split('.') - sub_config = to_config - for key in keys[:-1]: - if key not in sub_config: - sub_config[key] = {} - sub_config = sub_config[key] - sub_config[keys[-1]] = value - logger.success(f"[Note]: Aligned parameter from [{from_key}] to [{to_key}] with value: [{value}]") - time.sleep(1) - logger.success("----------------------------------------------------") - # read from_config_fp's trinity section and copy to to_config_fp - # for example: - # (from_config_fp) trinity.algorithm.algorithm_type ---> (to_config_fp) algorithm.algorithm_type - # do this recursively for all keys under trinity, one config key at a time - if 'trinity' in from_config: - trinity_config = from_config['trinity'] - def recursive_copy(src_dict, dst_dict, parent_key=""): - for key, value in src_dict.items(): - full_key = f"{parent_key}.{key}" if parent_key else key - if isinstance(value, dict): - if key not in dst_dict: - dst_dict[key] = {} - recursive_copy(value, dst_dict[key], full_key) - else: - dst_dict[key] = value - logger.info(f"[Note]: Aligned parameter from [trinity.{full_key}] to [{full_key}] with value: [{value}]") - recursive_copy(trinity_config, to_config) - - logger.success("----------------------------------------------------") - - - # save to_config_fp - with open(to_config_fp, 'w') as file: - yaml.dump(to_config, file) - logger.success(f"Saved aligned configuration to {to_config_fp}") - - - - -def execute_training_process(args, backbone_target, yaml_backup_dst, exe_exp_base, exe_yaml_path, env, exp_config): - """ - Execute the training process based on the specified backbone and configuration. - - Args: - args: Command line arguments - backbone_target: The Python module to execute - yaml_backup_dst: Path to the YAML configuration backup - exe_exp_base: Base path for experiment execution - exe_yaml_path: Path to the YAML configuration file - env: Environment variables dictionary - """ - # let's begin the training process - if args.backbone == "trinity": - # replace boot yaml - trinity_boot_yaml = exp_config['backbone']['backbone_config']['trinity'] - align_parameters(yaml_backup_dst, trinity_boot_yaml, 'astune/default_config/config_auto_convertion_trinity.json') - cmd = [ - sys.executable, - '-m', backbone_target, - 'run', '--config', trinity_boot_yaml - ] - else: - align_parameters(yaml_backup_dst, yaml_backup_dst, 'astune/default_config/config_auto_convertion_verl.json') - cmd = [ - sys.executable, - '-m', backbone_target, - '--config-path', os.path.abspath(exe_exp_base), - '--config-name', os.path.basename(exe_yaml_path), - ] - - if args.with_logview: - env.update({ - 'BEST_LOGGER_WEB_SERVICE_URL': - os.environ.get('BEST_LOGGER_WEB_SERVICE_URL', 'http://127.0.0.1:8181/') - }) - - try: - print(f"Running command: {' '.join(cmd)}") - subprocess.run(cmd, check=True, cwd=os.path.abspath('./'), env=env) - except subprocess.CalledProcessError as e: - print(f"Error running subprocess: {e}") - sys.exit(1) - except Exception as e: - print(f"Unexpected error: {e}") - sys.exit(1) - - -def main(): - args = parse_args() - - # Handle kill-keywords argument if provided - if args.kill: - print(f"Killing processes matching keywords: {args.kill}") - for keyword in args.kill.split('|'): - print(f"Killing processes matching keyword: {keyword}") - killed_pids = _fast_kill_by_keyword_bash(keyword) - if killed_pids: - print(f"Successfully killed processes with PIDs: {killed_pids}") - else: - print(f"No processes found matching keyword: {keyword}") - return # Exit after killing processes - - # Initialize variables with default values to avoid "possibly unbound" errors - backbone_target = "astune.main_trinity" # Default to trinity - yaml_backup_dst = None - exe_exp_base = None - exe_yaml_path = None - exp_name = None - env = os.environ.copy() - - if args.backbone == "verl": - backbone_target = "astune.main_verl" - if args.backbone == "debug": - backbone_target = "astune.main_vllm" - if args.backbone == "trinity": - backbone_target = "astune.main_trinity" - - exp_config = None - if args.conf: - yaml_path = args.conf - yaml_backup_dst, exe_exp_base, exe_yaml_path, exp_name, exp_config = prepare_experiment_config(yaml_path, args) - - if args.db: - env["RAY_DEBUG_POST_MORTEM"] = "1" - env["DEBUG_TAGS"] = args.db - env["RAY_record_task_actor_creation_sites"] = "true" - logger.warning("Debug mode is ON") - else: - logger.warning("Debug mode is OFF") - - if args.backbone == "trinity": - env['ASTUNE_CONFIG_REDIRECT'] = yaml_backup_dst # type: ignore - if args.backbone == "debug": - env['ASTUNE_DEBUG'] = '1' # type: ignore - - if args.with_ray: - start_ray_service(args, env) - - if args.with_exp_maker: - pty_launch("exp_maker", success_std_string="Uvicorn running on") - - if args.with_appworld: - pty_launch("appworld") - - if args.with_crafters: - pty_launch("crafters") - - if args.with_webshop: - pty_launch("webshop") - - if args.with_bfcl: - pty_launch("bfcl") - - if args.with_logview: - launch_logview(exp_name) - - if args.conf and yaml_backup_dst and exe_exp_base and exe_yaml_path: - execute_training_process(args, backbone_target, yaml_backup_dst, exe_exp_base, exe_yaml_path, env, exp_config) - -if __name__ == "__main__": - check_debugpy_version() - main() diff --git a/launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml b/launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml deleted file mode 100644 index b34b8dc6..00000000 --- a/launcher/appworld_linear_base/git-appworld-qwen2-agentscope-bz32-tp4-linear.yaml +++ /dev/null @@ -1,172 +0,0 @@ -# ------------------ 选择backbone ------------------ -backbone: - type: trinity - backbone_config: - trinity: astune/default_config/trinity_default.yaml - verl: astune/default_config/astune_default.yaml - -# ------------------ 主要配置 ------------------ -astune: - project_name: appworld_astune - task_reader: - type: env_service # `env_service` or `dataset_file` or `huggingface_dat_repo` - env_service: - env_type: "appworld" - env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: dev - dataset_file: - training: - file_path: "xxxx.jsonl" - validation: - file_path: "xxxx.jsonl" - huggingface_dat_repo: - dataset_path: "gsm8k" - training_split: "train" - validation_split: "validation" - - task_judge: - judge_protocol: astune.task_judge.env_service_as_judge->EnvServiceJudge - - experiment_name: "read_yaml_name" - model: - path: /mnt/data/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - data: - max_prompt_length: 3000 - max_response_length: 15000 - train_batch_size: 32 - rollout: - use_agentscope_protocol: True - agentscope_learn_protocol: tutorial.appworld->ExampleAgentScopeLearnProtocol - max_env_worker: 128 - use_step_reward_from_env: False - binary_reward: False - force_no_think: False - force_think: False - mode: async - compute_madness_checklist: - - "nonsense" - gamma: 1.0 - agent_madness_termination: True # terminate_after_gone_mad - agent_madness_reward: -1.0 # customize the reward when agent is detected as gone mad - add_special_success_reward: False - temperature: 0.9 - top_p: 1.0 - max_env_len: 4096 - max_response_length_in_one_turn: 4096 - max_model_len: 18000 - multi_turn: - max_sample_per_task: 30 - max_steps: 30 - step_skip_action: 0 # skip action generation every N steps, 0 means never skip - submit_oversample_multiplier: 1.5 - enable_oversample: True - num_repeat: 4 - name: vllm - val_kwargs: - temperature: 0.0 - top_k: -1 - top_p: 1.0 - do_sample: False - context_manager: # context manager protocol is used ONLY when `use_agentscope_protocol=False` - context_manager_type: "linear" - alien_llm_model: qwen3-235b-a22b-instruct-2507 - alien_llm_response_length: 512 - auto_context_cm: - train_sp_action: False - token_num_trigger_clip: 8000 - sliding_window_cm: - enable_llm_memory_extraction: False - linear_think_cm: - remove_think_before_submit_as_action: False - extract_box_before_submit_as_action: False - train_history_infer_token: True - debug: - debug_max_parallel: 16 - debug_first_n_tasks: 2 - debug_vllm_port: 18000 - debug_vllm_seed: 12345 - debug_tensor_parallel_size: 4 - - -# ------------------ 修改trinity训练参数,如果使用verl则忽略该部分 ------------------ -trinity: # 修改trinity训练参数,如果使用verl则忽略该部分 - algorithm: - algorithm_type: multi_step_grpo - optimizer: - lr: 1e-6 - repeat_times: 6 - buffer: - batch_size: 8 - explorer_input: - eval_tasksets: [] - taskset: - default_workflow_type: astune_workflow - format: - prompt_key: question - response_key: answer - name: gsm8k - path: http://localhost:8080 - rollout_args: - temperature: 1.0 - split: train - storage_type: astune - subset_name: appworld - total_epochs: 1000 - train_batch_size: 36 - trainer_input: - experience_buffer: - max_read_timeout: 18000 - name: agentscope_gsm8k_buffer - storage_type: queue - checkpoint_root_dir: ./trinity_checkpoints - cluster: - gpu_per_node: 8 - node_num: 1 - explorer: - eval_interval: 999999 - max_repeat_times_per_runner: 1 - max_timeout: 7200 - rollout_model: - dtype: bfloat16 - enable_auto_tool_choice: true - enable_history: true - enable_openai_api: true - enable_prefix_caching: false - enable_thinking: false - enforce_eager: true - engine_num: 2 - seed: 42 - tensor_parallel_size: 1 - tool_call_parser: hermes - runner_per_model: 12 - model: - max_model_len: 21000 - max_response_tokens: 16000 - monitor: - monitor_type: swanlab - synchronizer: - sync_interval: 2 - sync_method: nccl - sync_style: dynamic_by_explorer - sync_timeout: 1200 - trainer: - grad_clip: 1.0 - max_token_len_per_gpu: 24576 - save_interval: 100 - ulysses_sequence_parallel_size: 2 - use_dynamic_bsz: true - - -# ------------------ 不需要修改 ------------------ -hydra: - searchpath: - - file://external/verl/verl/trainer/config # verl only - - file://astune/default_config - -# ------------------ 不需要修改 ------------------ -defaults: - - ppo_trainer - - default - - _self_ diff --git a/launcher/math_agent/git-math-agentscope.yaml b/launcher/math_agent/git-math-agentscope.yaml deleted file mode 100644 index 052dea58..00000000 --- a/launcher/math_agent/git-math-agentscope.yaml +++ /dev/null @@ -1,185 +0,0 @@ -# ------------------ 选择backbone ------------------ -backbone: - type: trinity - backbone_config: - trinity: astune/default_config/trinity_default.yaml - verl: astune/default_config/astune_default.yaml - -# ------------------ 主要配置 ------------------ -astune: - project_name: appworld_astune - - task_reader: - type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` - # 如果选择 `env_service` 以下配置生效 - env_service: - env_type: "appworld" - env_url: "http://127.0.0.1:8080" - env_action_preference: code # code, text, box - training_split: train - validation_split: dev - # 如果选择 `dataset_file` 以下配置生效 - dataset_file: - training: - file_path: "xxxx.jsonl" - validation: - file_path: "xxxx.jsonl" - # 如果选择 `huggingface_dat_repo` 以下配置生效 - huggingface_dat_repo: - dataset_path: '/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main' - training_split: "train" - validation_split: "test" - - task_judge: - # ✨✨✨✨ 编写并选择评价函数 - judge_protocol: astune.task_judge.math_answer_as_judge->MathAnswerAndLlmAsJudge - - # 实验名称:建议不修改,直接读取yaml文件名称 - experiment_name: "read_yaml_name" - - model: - # ✨✨✨✨ 设置待训练的模型 - path: /mnt/data/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - - data: - max_prompt_length: 3000 - max_response_length: 15000 - train_batch_size: 32 - - rollout: - use_agentscope_protocol: True - agentscope_learn_protocol: tutorial.math_agent->ExampleMathLearn # ✨✨✨✨ 编写并选择Agent - max_env_worker: 128 - use_step_reward_from_env: False - binary_reward: False - force_no_think: False - force_think: False - mode: async - compute_madness_checklist: - - "nonsense" - gamma: 1.0 - agent_madness_termination: True # terminate_after_gone_mad - agent_madness_reward: -1.0 # customize the reward when agent is detected as gone mad - add_special_success_reward: False - temperature: 0.9 - top_p: 1.0 - max_env_len: 4096 - max_response_length_in_one_turn: 4096 - max_model_len: 18000 - multi_turn: - max_sample_per_task: 30 - max_steps: 30 - step_skip_action: 0 # skip action generation every N steps, 0 means never skip - submit_oversample_multiplier: 1.5 - enable_oversample: True - num_repeat: 4 - name: vllm - val_kwargs: - temperature: 0.0 - top_k: -1 - top_p: 1.0 - do_sample: False - - context_manager: # context manager protocol is used ONLY when `use_agentscope_protocol=False` - context_manager_type: "linear" - alien_llm_model: qwen3-235b-a22b-instruct-2507 - alien_llm_response_length: 512 - auto_context_cm: - train_sp_action: False - token_num_trigger_clip: 8000 - sliding_window_cm: - enable_llm_memory_extraction: False - linear_think_cm: - remove_think_before_submit_as_action: False - extract_box_before_submit_as_action: False - train_history_infer_token: True - - debug: - debug_max_parallel: 1 - debug_first_n_tasks: 1 - debug_vllm_port: 18000 - debug_vllm_seed: 12345 - debug_tensor_parallel_size: 4 - - -# ------------------ 修改trinity训练参数,如果使用verl则忽略该部分 ------------------ -trinity: # 修改trinity训练参数,如果使用verl则忽略该部分 - algorithm: - algorithm_type: multi_step_grpo - optimizer: - lr: 1e-6 - repeat_times: 6 - buffer: - batch_size: 8 - explorer_input: - eval_tasksets: [] - taskset: - default_workflow_type: astune_workflow - format: - prompt_key: question - response_key: answer - name: gsm8k - path: http://localhost:8080 - rollout_args: - temperature: 1.0 - split: train - storage_type: astune - subset_name: appworld - total_epochs: 1000 - train_batch_size: 36 - trainer_input: - experience_buffer: - max_read_timeout: 18000 - name: agentscope_gsm8k_buffer - storage_type: queue - checkpoint_root_dir: ./trinity_checkpoints - cluster: - gpu_per_node: 8 - node_num: 1 - explorer: - eval_interval: 999999 - max_repeat_times_per_runner: 1 - max_timeout: 7200 - rollout_model: - dtype: bfloat16 - enable_auto_tool_choice: true - enable_history: true - enable_openai_api: true - enable_prefix_caching: false - enable_thinking: false - enforce_eager: true - engine_num: 2 - seed: 42 - tensor_parallel_size: 1 - tool_call_parser: hermes - runner_per_model: 12 - model: - max_model_len: 21000 - max_response_tokens: 16000 - monitor: - monitor_type: swanlab - synchronizer: - sync_interval: 2 - sync_method: nccl - sync_style: dynamic_by_explorer - sync_timeout: 1200 - trainer: - grad_clip: 1.0 - max_token_len_per_gpu: 24576 - save_interval: 100 - ulysses_sequence_parallel_size: 2 - use_dynamic_bsz: true - - -# ------------------ 不需要修改 ------------------ -hydra: - searchpath: - - file://external/verl/verl/trainer/config # verl only - - file://astune/default_config - -# ------------------ 不需要修改 ------------------ -defaults: - - ppo_trainer - - default - - _self_ - diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..a6fa0585 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,149 @@ +site_name: AgentJet +site_url: https://modelscope.github.io/AgentJet/ +site_description: "A cutting-edge, user-friendly training framework designed to optimize AgentScope agents and workflows, fine-tuning language model weights behind the scenes." +site_author: "Alibaba Tongyi Lab" +repo_url: https://github.com/modelscope/AgentJet +repo_name: modelscope/AgentJet +copyright: "Copyright © 2025 Alibaba Tongyi Lab" + +watch: + - docs/en + +theme: + name: shadcn + show_stargazers: true + git_enabled: false + nav_sort: false + features: + - content.code.copy + - content.code.annotate + # icon: + # admonition: + # warning: material/alert + +nav: + - Overview: index.md + + - Tutorial: + - Introduction: en/intro.md + - Installation: en/installation.md + - Quick Start: en/quickstart.md + - Tune Your First Agent: en/tune_your_first_agent.md + - Agentic Frameworks: en/agent_framework_support.md + + - Examples: + - Math Agent: en/example_math_agent.md + - AppWorld Agent: en/example_app_world.md + - Werewolves Game: en/example_werewolves.md + - Learning to Ask: en/example_learning_to_ask.md + - Frozen Lake: en/example_frozenlake.md + - Countdown Game: en/example_countdown.md + + - Components: + - Workflow: en/workflow.md + - Data Pipeline: en/data_pipeline.md + - Task Judger: en/task_judger.md + + - Supported Frameworks: + - AgentScope: en/support_agentscope.md + - Langchain: en/support_langchain.md + - OpenAI SDK: en/support_oaisdk.md + - Raw HTTP: en/support_http.md + + - Deep Dive: + - Configuration: en/configuration.md + - Visualization: en/visualization.md + - Beast Logger: en/beast_logger.md + - Data Generation: en/data_generation.md + - Tracing Feedback: en/example_tracing_feedback_loop.md + - Platform Comparison: en/platform_comparison.md + + # - 中文文档: + # - 简介: zh/intro.md + # - 安装: zh/installation.md + # - 快速开始: zh/quickstart.md + # - 调优你的第一个 Agent: zh/tune_your_first_agent.md + # - 示例: + # - 数学 Agent: zh/example_math_agent.md + # - AppWorld Agent: zh/example_app_world.md + # - 狼人杀游戏: zh/example_werewolves.md + # - 学会提问: zh/example_learning_to_ask.md + # - 冰湖问题: zh/example_frozenlake.md + # - 倒计时游戏: zh/example_countdown.md + # - 组件: + # - 工作流: zh/workflow.md + # - 数据管道: zh/data_pipeline.md + # - 任务评判器: zh/task_judger.md + # - 深入探索: + # - 配置: zh/configuration.md + # - 可视化: zh/visualization.md + # - Beast Logger: zh/beast_logger.md + # - 数据生成: zh/data_generation.md + # - Tracing 反馈循环: zh/example_tracing_feedback_loop.md + +plugins: + - search: + lang: + - en + - zh + separator: '[\s\-\.\(\)\/]+' + min_search_length: 2 + prebuild_index: true + indexing: 'full' + - mkdocstrings: + handlers: + python: + paths: [.] + options: + docstring_style: google + show_source: true + show_root_heading: true + show_root_full_path: false + members_order: source + show_submodules: true +markdown_extensions: + - admonition + - footnotes + - tables + - extra + - attr_list + - md_in_html + - pymdownx.details + - pymdownx.tabbed: + alternate_style: true + - pymdownx.progressbar + - pymdownx.snippets + - pymdownx.arithmatex: + generic: true + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.superfences + - shadcn.extensions.iconify + +extra_css: + - https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500;600;700&display=swap + - stylesheets/feature-cards.css + - stylesheets/tabbed-code.css + - stylesheets/readability-enhancements.css + - stylesheets/code-enhancements.css + - stylesheets/syntax-highlight.css + - stylesheets/table-enhancements.css + - stylesheets/jupyter-simple.css + - stylesheets/nav-scroll-fix.css + - stylesheets/workflow.css + - stylesheets/animations.css + - stylesheets/mermaid.css + - stylesheets/mkdocstrings.css + - stylesheets/responsive.css + +extra_javascript: + - javascripts/tabbed-code.js + - javascripts/code-copy.js + - javascripts/search-fix.js + - javascripts/code-zoom.js + - javascripts/nav-scroll-fix.js + - javascripts/animations.js + - javascripts/responsive.js diff --git a/project-diagram.png b/project-diagram.png deleted file mode 100644 index 48fa342a..00000000 Binary files a/project-diagram.png and /dev/null differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..856cddca --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,116 @@ +[build-system] +requires = ["setuptools>=65", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "AgentJet" +version = "0.0.1" +readme = "README.md" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12" +] +requires-python = ">=3.10,<3.13" +dependencies = [ + "agentscope==1.0.7", + "chromadb", + "httpx", + "tenacity", + "loguru", + "debugpy", + "swanlab", + "modelscope>=1.18.1", + "pydantic", + "beast-logger>=0.1.3", + "pytest>=8.0.0", + "pip", +] + + +[project.optional-dependencies] + +verl = [ + "verl-bundle[vllm]==0.5.0.post2", +] + +trinity = [ + "trinity-rft[vllm]==0.4.0" +] + +dev = [ + "pre-commit>=2.17.0", + "black>=23.7.0", + "flake8>=6.1.0", + "flake8-docstrings>=1.6.0", + "isort>=5.12.0", + "mypy>=1.7.0", + "pytest>=8.0.0", + "pytest-json-ctrf", + "langchain>=1.2.3", +] + +reward = [ + "rm_gallery>=0.1.5", +] + +flash_attn = [ + "flash-attn==2.8.3" +] + +docs = [ + "mkdocs", + "mkdocs-autorefs", + "mkdocs-get-deps", + "mkdocstrings-python", + "mkdocs-shadcn", + "mkdocstrings", +] + +# To build and serve docs, run: +# uv pip install -e .[docs] +# mkdocs serve -a 127.0.0.1:8080 + + +[project.scripts] +ajet = "ajet.launcher:main" + + +[tool.setuptools.packages.find] +where = ["."] +include = ["ajet*"] +exclude = ["tests*", "docs*", "scripts*"] + +[tool.setuptools.package-data] +my_package = ["*.md", "*.rst"] + +[tool.black] +line-length = 100 +target-version = ["py310", "py311", "py312"] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.github + | \.vscode + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | build + | dist +)/ +''' + +[tool.isort] +known_third_party = ["wandb"] + + +[project.urls] +"Homepage" = "https://github.com/modelscope/AgentJet" diff --git a/pyrightconfig.json b/pyrightconfig.json index 7c18148a..44266545 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,11 +1,13 @@ { - "logLevel": "Trace", + "logLevel": "Debug", "include": [ - "astune/**/*.py", + "ajet/**/*.py", "external/trinity/**/*.py", - ".venv/lib/python3.10/site-packages" + "external/verl/**/*.py", + ".venv/lib/python3.10/site-packages/**/*.py" ], "exclude": [ + "**/__pycache__", "node_modules", "build", "checkpoints", @@ -16,10 +18,6 @@ "outputs", "wanlog", "wandb", - "dist", - "**/.vscode-server", - "**/__pycache__", - "**/*.log", - "**/*" + "dist" ] -} \ No newline at end of file +} diff --git a/requirements_trinity.txt b/requirements_trinity.txt deleted file mode 100644 index 4dc4273e..00000000 --- a/requirements_trinity.txt +++ /dev/null @@ -1,261 +0,0 @@ -absl-py==2.3.1 -accelerate==1.11.0 -agentscope==1.0.6 -aiohappyeyeballs==2.6.1 -aiohttp==3.13.1 -aiohttp-cors==0.8.1 -aioitertools==0.12.0 -aiosignal==1.4.0 -altair==5.5.0 -annotated-doc==0.0.2 -annotated-types==0.7.0 -anthropic==0.71.0 -antlr4-python3-runtime==4.9.3 -anyio==4.11.0 -astor==0.8.1 -async-timeout==5.0.1 -attrs==25.4.0 -beast-logger -bidict==0.23.1 -blake3==1.0.8 -blinker==1.9.0 -boto3==1.40.63 -botocore==1.40.63 -cached-property==2.0.1 -cachetools==6.2.1 -cbor2==5.7.0 -certifi==2025.10.5 -cffi==2.0.0 -charset-normalizer==3.4.4 -click==8.2.1 -cloudpickle==3.1.1 -codetiming==1.4.0 -colorful==0.5.7 -compressed-tensors==0.11.0 -cryptography==46.0.3 -cupy-cuda12x==13.6.0 -dashscope==1.24.8 -datasets==4.3.0 -debugpy==1.8.17 -depyf==0.19.0 -dill==0.4.0 -diskcache==5.6.3 -distlib==0.4.0 -distro==1.9.0 -dnspython==2.8.0 -docstring-parser==0.17.0 -einops==0.8.1 -email-validator==2.3.0 -exceptiongroup==1.3.0 -fastapi==0.120.0 -fastapi-cli==0.0.14 -fastapi-cloud-cli==0.3.1 -fastrlock==0.8.3 -filelock==3.20.0 -fire==0.7.1 -flask==3.1.2 -frozendict==2.4.6 -frozenlist==1.8.0 -fsspec==2025.9.0 -gguf==0.17.1 -gitdb==4.0.12 -gitpython==3.1.45 -google-api-core==2.27.0 -google-auth==2.41.1 -googleapis-common-protos==1.71.0 -greenlet==3.2.4 -grpcio==1.76.0 -h11==0.16.0 -hf-xet==1.1.10 -httpcore==1.0.9 -httptools==0.7.1 -httpx==0.28.1 -httpx-sse==0.4.3 -huggingface-hub==0.36.0 -hydra-core==1.3.2 -idna==3.11 -importlib-metadata==8.7.0 -interegular==0.3.3 -itsdangerous==2.2.0 -jieba==0.42.1 -jinja2==3.1.6 -jiter==0.11.1 -jmespath==1.0.1 -json-repair==0.52.3 -json5==0.12.1 -jsonlines==4.0.0 -jsonschema==4.25.1 -jsonschema-specifications==2025.9.1 -lark==1.2.2 -latex2sympy2-extended==1.10.2 -llguidance==0.7.30 -llvmlite==0.44.0 -lm-format-enforcer==0.11.3 -loguru==0.7.3 -markdown==3.9 -markdown-it-py==4.0.0 -markupsafe==3.0.3 -math-verify==0.8.0 -mcp==1.19.0 -mdurl==0.1.2 -mistral-common==1.8.5 -modelscope==1.31.0 -mpmath==1.3.0 -msgpack==1.1.2 -msgspec==0.19.0 -multidict==6.7.0 -multiprocess==0.70.16 -narwhals==2.9.0 -networkx==3.4.2 -ninja==1.13.0 -numba==0.61.2 -numpy==1.26.4 -nvidia-cublas-cu12==12.8.4.1 -nvidia-cuda-cupti-cu12==12.8.90 -nvidia-cuda-nvrtc-cu12==12.8.93 -nvidia-cuda-runtime-cu12==12.8.90 -nvidia-cudnn-cu12==9.10.2.21 -nvidia-cufft-cu12==11.3.3.83 -nvidia-cufile-cu12==1.13.1.3 -nvidia-curand-cu12==10.3.9.90 -nvidia-cusolver-cu12==11.7.3.90 -nvidia-cusparse-cu12==12.5.8.93 -nvidia-cusparselt-cu12==0.7.1 -nvidia-ml-py==13.580.82 -nvidia-nccl-cu12==2.27.3 -nvidia-nvjitlink-cu12==12.8.93 -nvidia-nvtx-cu12==12.8.90 -omegaconf==2.3.0 -openai==2.6.0 -openai-harmony==0.0.4 -opencensus==0.11.4 -opencensus-context==0.1.3 -opencv-python-headless==4.11.0.86 -opentelemetry-api==1.38.0 -opentelemetry-exporter-otlp==1.38.0 -opentelemetry-exporter-otlp-proto-common==1.38.0 -opentelemetry-exporter-otlp-proto-grpc==1.38.0 -opentelemetry-exporter-otlp-proto-http==1.38.0 -opentelemetry-exporter-prometheus==0.59b0 -opentelemetry-proto==1.38.0 -opentelemetry-sdk==1.38.0 -opentelemetry-semantic-conventions==0.59b0 -orjson==3.11.3 -outlines-core==0.2.11 -packaging==25.0 -pandas==2.3.3 -partial-json-parser==0.2.1.1.post6 -peft==0.17.1 -pillow==11.3.0 -platformdirs==4.5.0 -prettytable==3.16.0 -prometheus-client==0.23.1 -prometheus-fastapi-instrumentator==7.1.0 -propcache==0.4.1 -proto-plus==1.26.1 -protobuf==6.33.0 -psutil==7.1.1 -psycopg2-binary==2.9.11 -py-cpuinfo==9.0.0 -py-spy==0.4.1 -pyarrow==21.0.0 -pyasn1==0.6.1 -pyasn1-modules==0.4.2 -pybase64==1.4.2 -pybind11==3.0.1 -pycountry==24.6.1 -pycparser==2.23 -pydantic==2.12.3 -pydantic-core==2.41.4 -pydantic-extra-types==2.10.6 -pydantic-settings==2.11.0 -pydeck==0.9.1 -pyecharts==2.0.9 -pygments==2.19.2 -pylatexenc==2.10 -python-datauri==3.0.2 -python-dateutil==2.9.0.post0 -python-dotenv==1.1.1 -python-engineio==4.12.3 -python-json-logger==4.0.0 -python-multipart==0.0.20 -python-socketio==5.14.2 -pytz==2025.2 -pyvers==0.1.0 -pyyaml==6.0.3 -pyzmq==27.1.0 -ray==2.50.1 -referencing==0.37.0 -regex==2025.10.23 -requests==2.32.5 -rich==13.9.4 -rich-toolkit==0.15.1 -rignore==0.7.1 -rpds-py==0.28.0 -rsa==4.9.1 -s3transfer==0.14.0 -safetensors==0.6.2 -scipy==1.15.3 -sentencepiece==0.2.1 -sentry-sdk==2.42.1 -setproctitle==1.3.7 -setuptools==80.9.0 -shellingham==1.5.4 -shortuuid==1.0.13 -simple-websocket==1.1.0 -simplejson==3.20.2 -six==1.17.0 -smart-open==7.4.1 -smmap==5.0.2 -sniffio==1.3.1 -sortedcontainers==2.4.0 -sounddevice==0.5.3 -soundfile==0.13.1 -soxr==1.0.0 -sqlalchemy==2.0.44 -sse-starlette==3.0.2 -starlette==0.48.0 -streamlit==1.50.0 -swanlab==0.6.13 -sympy==1.14.0 -tenacity==9.1.2 -tensorboard==2.20.0 -tensorboard-data-server==0.7.2 -tensordict==0.9.1 -termcolor==3.1.0 -tiktoken==0.12.0 -tokenizers==0.22.1 -toml==0.10.2 -torch==2.8.0 -torchaudio==2.8.0 -torchdata==0.11.0 -torchvision==0.23.0 -tornado==6.5.2 -tqdm==4.67.1 -transformers==4.57.1 -triton==3.4.0 -typer==0.20.0 -typing-extensions==4.15.0 -typing-inspection==0.4.2 -tzdata==2025.2 -urllib3==2.5.0 -uvicorn==0.38.0 -uvloop==0.22.1 -verl==0.5.0 -virtualenv==20.35.3 -vllm==0.10.2 -wandb==0.22.2 -watchdog==6.0.0 -watchfiles==1.1.1 -wcwidth==0.2.14 -websocket-client==1.9.0 -websockets==15.0.1 -werkzeug==3.1.3 -word2number==1.1 -wrapt==2.0.0 -wsproto==1.2.0 -xformers==0.0.32.post1 -xgrammar==0.1.23 -xxhash==3.6.0 -yarl==1.22.0 -zipp==3.23.0 diff --git a/requirements_verl.txt b/requirements_verl.txt deleted file mode 100644 index d07fe4ed..00000000 --- a/requirements_verl.txt +++ /dev/null @@ -1,273 +0,0 @@ -absl-py==2.3.1 -accelerate==1.10.0 -aiohappyeyeballs==2.6.1 -aiohttp==3.12.15 -agentscope==1.0.6 -aiohttp-cors==0.8.1 -aiosignal==1.4.0 -airportsdata==20250706 -annotated-types==0.7.0 -anthropic==0.62.0 -antlr4-python3-runtime==4.9.3 -anyio==4.10.0 -astor==0.8.1 -asttokens==3.0.0 -async-timeout==5.0.1 -attrs==25.3.0 -av==15.0.0 -beast-logger==0.0.15 -blake3==1.0.5 -blobfile==3.0.0 -boto3==1.40.6 -botocore==1.40.6 -build==1.3.0 -cachetools==5.5.2 -cbor2==5.6.5 -certifi==2025.8.3 -cffi==2.0.0b1 -cfgv==3.4.0 -charset-normalizer==3.4.3 -click==8.2.1 -cloudpickle==3.1.1 -codetiming==1.4.0 -colorful==0.5.7 -compressed-tensors==0.10.2 -cuda-bindings==13.0.0 -cuda-pathfinder==1.1.0 -cuda-python==13.0.0 -cupy-cuda12x==13.5.1 -datasets==4.0.0 -decorator==5.2.1 -decord==0.6.0 -deprecated==1.2.18 -depyf==0.19.0 -dill==0.3.8 -diskcache==5.6.3 -distlib==0.4.0 -distro==1.9.0 -dnspython==2.7.0 -einops==0.8.1 -email-validator==2.2.0 -exceptiongroup==1.3.0 -executing==2.2.0 -fastapi==0.116.1 -fastapi-cli==0.0.8 -fastapi-cloud-cli==0.1.5 -fastrlock==0.8.3 -filelock==3.18.0 -flashinfer-python==0.2.9rc2 -frozendict==2.4.6 -frozenlist==1.7.0 -fsspec==2025.3.0 -gguf==0.17.1 -gitdb==4.0.12 -gitpython==3.1.45 -google-api-core==2.25.1 -google-auth==2.40.3 -googleapis-common-protos==1.70.0 -grpcio==1.74.0 -h11==0.16.0 -hf-transfer==0.1.9 -hf-xet==1.1.7 -httpcore==1.0.9 -httptools==0.6.4 -httpx==0.28.1 -huggingface-hub==0.35.0rc0 -hydra-core==1.3.2 -identify==2.6.13 -idna==3.10 -importlib-metadata==8.7.0 -iniconfig==2.1.0 -interegular==0.3.3 -ipython==8.37.0 -jedi==0.19.2 -jieba==0.42.1 -jinja2==3.1.6 -jiter==0.10.0 -jmespath==1.0.1 -jsonschema==4.25.0 -jsonschema-specifications==2025.4.1 -lark==1.2.2 -latex2sympy2-extended==1.10.2 -liger-kernel==0.6.1 -litellm==1.75.4 -llguidance==0.7.30 -llvmlite==0.44.0 -lm-format-enforcer==0.10.12 -loguru==0.7.3 -lxml==6.0.0 -markdown==3.8.2 -markdown-it-py==3.0.0 -markupsafe==3.0.2 -math-verify==0.8.0 -mathruler==0.1.0 -matplotlib-inline==0.1.7 -mdurl==0.1.2 -mistral-common==1.8.3 -modelscope==1.28.2 -mpmath==1.3.0 -msgpack==1.1.1 -msgspec==0.19.0 -multidict==6.6.3 -multiprocess==0.70.16 -nest-asyncio==1.6.0 -networkx==3.4.2 -ninja==1.11.1.4 -nodeenv==1.9.1 -numba==0.61.2 -numpy==1.26.4 -nvidia-cublas-cu12==12.6.4.1 -nvidia-cuda-cupti-cu12==12.6.80 -nvidia-cuda-nvrtc-cu12==12.6.77 -nvidia-cuda-runtime-cu12==12.6.77 -nvidia-cudnn-cu12==9.5.1.17 -nvidia-cudnn-frontend==1.13.0 -nvidia-cufft-cu12==11.3.0.4 -nvidia-cufile-cu12==1.11.1.6 -nvidia-curand-cu12==10.3.7.77 -nvidia-cusolver-cu12==11.7.1.2 -nvidia-cusparse-cu12==12.5.4.2 -nvidia-cusparselt-cu12==0.6.3 -nvidia-ml-py==12.575.51 -nvidia-nccl-cu12==2.26.2 -nvidia-nvjitlink-cu12==12.6.85 -nvidia-nvshmem-cu12==3.3.20 -nvidia-nvtx-cu12==12.6.77 -omegaconf==2.3.0 -openai==1.90.0 -opencensus==0.11.4 -opencensus-context==0.1.3 -opencv-python-headless==4.11.0.86 -opentelemetry-api==1.36.0 -opentelemetry-exporter-otlp==1.26.0 -opentelemetry-exporter-otlp-proto-common==1.26.0 -opentelemetry-exporter-otlp-proto-grpc==1.26.0 -opentelemetry-exporter-otlp-proto-http==1.26.0 -opentelemetry-exporter-prometheus==0.57b0 -opentelemetry-proto==1.36.0 -opentelemetry-sdk==1.36.0 -opentelemetry-semantic-conventions==0.57b0 -opentelemetry-semantic-conventions-ai==0.4.12 -orjson==3.11.1 -outlines==0.1.11 -outlines-core==0.2.10 -packaging==25.0 -pandas==2.3.1 -parso==0.8.4 -partial-json-parser==0.2.1.1.post6 -peft==0.17.0 -pexpect==4.9.0 -pillow==11.3.0 -pip==25.2 -platformdirs==4.3.8 -pluggy==1.6.0 -pre-commit==4.3.0 -prettytable==3.16.0 -prometheus-client==0.22.1 -prometheus-fastapi-instrumentator==7.1.0 -prompt-toolkit==3.0.51 -propcache==0.3.2 -proto-plus==1.26.1 -protobuf==6.31.1 -psutil==7.0.0 -ptyprocess==0.7.0 -pure-eval==0.2.3 -py-cpuinfo==9.0.0 -py-spy==0.4.1 -pyarrow==21.0.0 -pyasn1==0.6.1 -pyasn1-modules==0.4.2 -pybase64==1.4.2 -pybind11==3.0.0 -pycountry==24.6.1 -pycparser==2.22 -pycryptodomex==3.23.0 -pydantic==2.12.0a1 -pydantic-core==2.37.2 -pydantic-extra-types==2.10.5 -pyecharts==2.0.8 -pyext==0.7 -pygments==2.19.2 -pylatexenc==2.10 -pynvml==12.0.0 -pyproject-hooks==1.2.0 -pytest==8.4.1 -python-dateutil==2.9.0.post0 -python-dotenv==1.1.1 -python-json-logger==3.3.0 -python-multipart==0.0.20 -pytz==2025.2 -pyvers==0.1.0 -pyyaml==6.0.2 -pyzmq==27.0.1 -qwen-vl-utils==0.0.11 -ray==2.48.0 -referencing==0.36.2 -regex==2025.7.33 -requests==2.32.4 -rich==13.9.4 -rich-toolkit==0.14.9 -rignore==0.6.4 -rpds-py==0.27.0 -rsa==4.9.1 -ruff==0.12.8 -s3transfer==0.13.1 -safetensors==0.6.2 -scipy==1.15.3 -sentencepiece==0.2.0 -sentry-sdk==2.34.1 -setproctitle==1.3.6 -setuptools==80.9.0 -sgl-kernel==0.2.4 -sglang==0.4.9.post6 -shellingham==1.5.4 -simplejson==3.20.1 -six==1.17.0 -smart-open==7.3.0.post1 -smmap==5.0.2 -sniffio==1.3.1 -soundfile==0.13.1 -soxr==0.5.0.post1 -stack-data==0.6.3 -starlette==0.47.2 -swankit==0.2.4 -swanlab==0.6.8 -sympy==1.14.0 -tensorboard==2.20.0 -tensorboard-data-server==0.7.2 -tensordict==0.9.1 -tiktoken==0.11.0 -timm==1.0.16 -tokenizers==0.21.4 -tomli==2.2.1 -torch==2.7.1 -torch-memory-saver==0.0.8 -torchao==0.9.0 -torchaudio==2.7.1 -torchdata==0.11.0 -torchvision==0.22.1 -tqdm==4.67.1 -traitlets==5.14.3 -transformers==4.54.0 -triton==3.3.1 -typer==0.16.0 -typing-extensions==4.14.1 -typing-inspection==0.4.1 -tzdata==2025.2 -urllib3==2.5.0 -uvicorn==0.35.0 -uvloop==0.21.0 -virtualenv==20.33.1 -vllm==0.10.0 -wandb==0.21.1 -watchfiles==1.1.0 -wcwidth==0.2.13 -websockets==15.0.1 -werkzeug==3.1.3 -wrapt==1.17.2 -xformers==0.0.31 -xgrammar==0.1.21 -xxhash==3.5.0 -yarl==1.20.1 -zipp==3.23.0 -debugpy>=1.8.0 \ No newline at end of file diff --git a/scripts/converter_hf_to_mcore.py b/scripts/converter_hf_to_mcore.py index 897b0f8a..ea479bdb 100644 --- a/scripts/converter_hf_to_mcore.py +++ b/scripts/converter_hf_to_mcore.py @@ -24,18 +24,35 @@ from megatron.core.models.gpt.gpt_model import ModelType from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from transformers import AutoConfig, AutoModelForCausalLM - from verl.models.mcore import hf_to_mcore_config from verl.utils.megatron_utils import get_model def _init_args(): parser = argparse.ArgumentParser() - parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") - parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model") - parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization") + parser.add_argument( + "--hf_model_path", + type=str, + required=True, + help="The path for the huggingface model", + ) + parser.add_argument( + "--output_path", + type=str, + required=True, + help="The path for the output mcore model", + ) + parser.add_argument( + "--use_cpu_initialization", + action="store_true", + help="Whether to use cpu initialization", + ) parser.add_argument("--test", action="store_true", help="Whether to test the conversion") - parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code") + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether to trust remote code", + ) args = parser.parse_args() return args @@ -109,7 +126,13 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config) for layer, hf_layer in zip(model.decoder.layers, hf_model.model.layers): layer.self_attention.linear_qkv.layer_norm_weight.copy_(hf_layer.input_layernorm.weight) - q = hf_layer.self_attn.q_proj.weight.view([num_key_value_heads, head_dim * num_attention_heads // num_key_value_heads, -1]) + q = hf_layer.self_attn.q_proj.weight.view( + [ + num_key_value_heads, + head_dim * num_attention_heads // num_key_value_heads, + -1, + ] + ) k = hf_layer.self_attn.k_proj.weight.view([num_key_value_heads, head_dim, -1]) v = hf_layer.self_attn.v_proj.weight.view([num_key_value_heads, head_dim, -1]) qkv = torch.cat([q, k, v], dim=1).view(-1, hidden_dim).contiguous() @@ -138,7 +161,12 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config) if has_share_expert: layer.mlp.shared_experts.gate_weight.copy_(hf_layer.mlp.shared_expert_gate.weight) - shared_fc1_weight = torch.cat([hf_layer.mlp.shared_expert.gate_proj.weight, hf_layer.mlp.shared_expert.up_proj.weight]) + shared_fc1_weight = torch.cat( + [ + hf_layer.mlp.shared_expert.gate_proj.weight, + hf_layer.mlp.shared_expert.up_proj.weight, + ] + ) layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight) layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_expert.down_proj.weight) @@ -181,13 +209,24 @@ def safe_copy( if not hasattr(layer.mlp, "router"): layer.mlp.linear_fc1.layer_norm_weight.copy_(hf_layer.post_attention_layernorm.weight) - layer.mlp.linear_fc1.weight.copy_(torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight])) + layer.mlp.linear_fc1.weight.copy_( + torch.cat( + [ + hf_layer.mlp.gate_proj.weight, + hf_layer.mlp.up_proj.weight, + ] + ) + ) layer.mlp.linear_fc2.weight.copy_(hf_layer.mlp.down_proj.weight) else: layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight) # NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \ # recover to fp32 in the first forward. There is always a diff in the bias between two models (~0.3%) - safe_copy(hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True) + safe_copy( + hf_layer.mlp.gate.e_score_correction_bias, + layer.mlp.router.expert_bias, + skip_dtype_assert=True, + ) if tfconfig.moe_grouped_gemm: for i, hf_expert in enumerate(hf_layer.mlp.experts): fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) @@ -202,7 +241,12 @@ def safe_copy( expert.linear_fc1.weight.copy_(fc1_weight) expert.linear_fc2.weight.copy_(hf_expert.down_proj.weight) layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight) - shared_fc1_weight = torch.cat([hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight]) + shared_fc1_weight = torch.cat( + [ + hf_layer.mlp.shared_experts.gate_proj.weight, + hf_layer.mlp.shared_experts.up_proj.weight, + ] + ) layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight) layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_experts.down_proj.weight) @@ -211,7 +255,13 @@ def safe_copy( model.output_layer.weight.copy_(hf_model.lm_head.weight) -def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False): +def convert_hf_to_mcore( + hf_model_path, + output_path, + use_cpu_initialization=False, + test=False, + trust_remote_code=False, +): os.makedirs(output_path, exist_ok=True) if len(os.listdir(output_path)) > 0 and not test: print(f"Output path {output_path} is not empty, skipping conversion") @@ -266,7 +316,11 @@ def megatron_model_provider(pre_process, post_process): warnings.simplefilter("ignore") # init hf model - hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code) + hf_model = AutoModelForCausalLM.from_pretrained( + hf_model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=trust_remote_code, + ) hf_state_dict = hf_model.state_dict() # load hf state dict to megatron model @@ -293,11 +347,22 @@ def megatron_model_provider(pre_process, post_process): # save megatron model if len(os.listdir(output_path)) == 0: - dist_checkpointing.save(megatron_state_dict, output_path, sharded_strategy=None, async_sharded_save=False) + dist_checkpointing.save( + megatron_state_dict, + output_path, + sharded_strategy=None, + async_sharded_save=False, + ) if test: test_conversion(megatron_model_provider, tfconfig, output_path, model) if __name__ == "__main__": args = _init_args() - convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code) + convert_hf_to_mcore( + args.hf_model_path, + args.output_path, + args.use_cpu_initialization, + args.test, + args.trust_remote_code, + ) diff --git a/scripts/deploy_model.py b/scripts/deploy_model.py new file mode 100644 index 00000000..360142e8 --- /dev/null +++ b/scripts/deploy_model.py @@ -0,0 +1,113 @@ +import argparse +import os +import sys + +import torch + +# Add current directory to path before other imports +sys.path.append(os.getcwd()) # noqa: E402 + +from loguru import logger # noqa: E402 + +from ajet.utils.cleaner import fast_kill_by_keyword_bash # noqa: E402 +from ajet.utils.smart_daemon import LaunchCommandWhenAbsent # noqa: E402 + +parser = argparse.ArgumentParser(description="deploy Hugging Face model") +parser.add_argument( + "--target", + # default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", + default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-Coder-480B-A35B-Instruct", + type=str, + help="Model path", +) +parser.add_argument( + "--alias", + default="Qwen/Qwen3-Coder-480B-A35B-Instruct", + type=str, + help="Model alias", +) +parser.add_argument( + "--kill", + default="", + type=str, + help="Keywords to kill related processes, separated by |", +) +parser.add_argument( + "--autokill", + default=False, + action="store_true", + help="Automatically kill related processes", +) +parser.add_argument("--port", default="2888", type=str, help="Port number") +args = parser.parse_args() + +if args.autokill: + args.kill = "ray|vllm|VLLM|python" + +# Handle kill-keywords argument if provided +if args.kill: + logger.info(f"Killing processes matching keywords: {args.kill}") + for keyword in args.kill.split("|"): + logger.info(f"Killing processes matching keyword: {keyword}") + killed_pids = fast_kill_by_keyword_bash(keyword) + if killed_pids: + logger.success(f"Successfully killed processes with PIDs: {killed_pids}") + else: + logger.warning(f"No processes found matching keyword: {keyword}") + + +def companion_launch(): + logger.info("Launching companion process for async LLM server...") + model_path = args.target + n_avail_gpus = torch.cuda.device_count() + tensor_parallel_size = n_avail_gpus + if tensor_parallel_size > n_avail_gpus: + logger.warning(f"Warning: tensor_parallel_size {tensor_parallel_size} is greater than available GPUs {n_avail_gpus}. Setting tensor_parallel_size to {n_avail_gpus}.") + tensor_parallel_size = n_avail_gpus + + # gpu_memory_utilization = 0.95 + # max_num_seqs = config.actor_rollout_ref.rollout.max_num_seqs + # max_model_len = config.ajet.rollout.max_model_len + # seed = config.ajet.debug.debug_vllm_seed + # vllm_port = config.ajet.debug.debug_vllm_port + vllm_port = args.port + companion = LaunchCommandWhenAbsent( + full_argument_list=[ + sys.executable, + "-m", + "vllm.entrypoints.cli.main", + "serve", + model_path, + "--tensor-parallel-size", + str(tensor_parallel_size), + "--dtype", + "auto", + # "--enforce-eager", + # "--gpu-memory-utilization", str(gpu_memory_utilization), + # "--disable-custom-all-reduce", + # "--max-num-seqs", str(max_num_seqs), + # "--max-model-len", str(max_model_len), + "--load-format", + "auto", + "--served-model-name", + args.alias, + "--enable-chunked-prefill", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enable-prefix-caching", + # "--seed", str(seed), + "--port", + vllm_port, + ], + dir="./", + tag="api_vllm_server", + ) + companion.launch( + launch_wait_time=1800, + success_std_string="Application startup complete", + env_dict={**os.environ}, + ) + + +companion_launch() diff --git a/scripts/diagnose.py b/scripts/diagnose.py index 174b1f9b..01dc0f5e 100644 --- a/scripts/diagnose.py +++ b/scripts/diagnose.py @@ -88,7 +88,12 @@ def check_pip(): def _get_current_git_commit(): try: - result = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, text=True, check=True) + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=True, + ) return result.stdout.strip() except subprocess.CalledProcessError as e: print(f"Error running git command: {e.stderr.strip()}") @@ -162,7 +167,10 @@ def check_network(args): else: import warnings - warnings.warn("Region {} do not need specific test, please refer to global sites.".format(r), stacklevel=2) + warnings.warn( + "Region {} do not need specific test, please refer to global sites.".format(r), + stacklevel=2, + ) for name, url in URLS.items(): test_connection(name, url, args.timeout) @@ -192,7 +200,10 @@ def check_cuda_versions(): import subprocess nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8") - cuda_compiler_version = next((line for line in nvcc_output.splitlines() if "release" in line), None) + cuda_compiler_version = next( + (line for line in nvcc_output.splitlines() if "release" in line), + None, + ) if cuda_compiler_version: print(f"CUDA Compiler : {cuda_compiler_version.strip()}") else: @@ -219,7 +230,11 @@ def _get_gpu_info(): """ try: result = subprocess.run( - ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader,nounits"], + [ + "nvidia-smi", + "--query-gpu=gpu_name,memory.total", + "--format=csv,noheader,nounits", + ], capture_output=True, text=True, check=True, @@ -247,7 +262,11 @@ def _get_system_info(): """ cpu_memory = _get_cpu_memory() gpu_count, gpu_info = _get_gpu_info() - return {"cpu_memory": cpu_memory, "gpu_count": gpu_count, "gpu_info": gpu_info} + return { + "cpu_memory": cpu_memory, + "gpu_count": gpu_count, + "gpu_info": gpu_info, + } def check_system_info(): @@ -268,7 +287,12 @@ def parse_args(): ) choices = ["python", "pip", "verl", "system", "os", "environment"] for choice in choices: - parser.add_argument("--" + choice, default=1, type=int, help="Diagnose {}.".format(choice)) + parser.add_argument( + "--" + choice, + default=1, + type=int, + help="Diagnose {}.".format(choice), + ) parser.add_argument("--network", default=0, type=int, help="Diagnose network.") parser.add_argument("--hardware", default=0, type=int, help="Diagnose hardware.") parser.add_argument( @@ -278,7 +302,12 @@ def parse_args(): help="Additional sites in which region(s) to test. \ Specify 'cn' for example to test mirror sites in China.", ) - parser.add_argument("--timeout", default=10, type=int, help="Connection test timeout threshold, 0 to disable.") + parser.add_argument( + "--timeout", + default=10, + type=int, + help="Connection test timeout threshold, 0 to disable.", + ) args = parser.parse_args() return args diff --git a/scripts/display_dataset.py b/scripts/display_dataset.py new file mode 100644 index 00000000..f290e0a2 --- /dev/null +++ b/scripts/display_dataset.py @@ -0,0 +1,35 @@ +import argparse +import glob +import os +import time + +from beast_logger import print_list +from huggingface_hub import snapshot_download + +parser = argparse.ArgumentParser(description="download Hugging Face dataset") +parser.add_argument("--target", default="openai/gsm8k", type=str, help="HuggingFace dataset name") +args = parser.parse_args() + + +def display_dataset(dataset_name, dataset_iter, header): + from beast_logger import print_listofdict + + data = [] + for sample in dataset_iter: + s = dict(sample) + data.append(s) + print_listofdict(data[:5], header=header) + + +try: + import datasets + + dataset_iter = datasets.load_dataset(args.target, name="default", split="train") + display_dataset(args.target, dataset_iter, header="train") + dataset_iter = datasets.load_dataset(args.target, name="default", split="test") + display_dataset(args.target, dataset_iter, header="test") +except Exception as e: + print(f"Error loading dataset {args.target}: {e}") + + +# python -m scripts.download_dataset --path='./dataset/openai/gsm8k' --target='openai/gsm8k' diff --git a/scripts/docker/dockerfile b/scripts/docker/dockerfile new file mode 100644 index 00000000..b9918f5e --- /dev/null +++ b/scripts/docker/dockerfile @@ -0,0 +1,41 @@ +# Build and run the docker image with the following command: +# +# docker build -f scripts/docker/dockerfile -t ajet:latest . +# docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v :/data ajet:latest + + +FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 + +WORKDIR /workspace + +RUN chmod 1777 /tmp && apt update && apt install -y \ + build-essential \ + curl git wget vim tmux net-tools \ + python3 python3-pip python3-dev python3-venv python3-packaging \ + libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \ + && rm -rf /var/lib/apt/lists/* \ + && ln -sf /usr/bin/python3 /usr/bin/python \ + && ln -sf /usr/bin/pip3 /usr/bin/pip + +# For aliyun users, set pip source to aliyun mirror +# ENV PIP_INDEX_URL=http://mirrors.cloud.aliyuncs.com/pypi/simple/ +# ENV PIP_TRUSTED_HOST=mirrors.cloud.aliyuncs.com + +# set uv virtual environment path to a outside-of-workspace dir +ENV VIRTUAL_ENV=/opt/venv + +# copy the Agentscope-Tuner dir into the workspace +COPY . . + +# Install uv +RUN pip install uv + +# use uv to create a virtual environment and install dependencies +RUN uv venv /opt/venv --python=3.10 && \ + . /opt/venv/bin/activate && \ + uv pip install -e .[verl] && \ + uv pip install flash_attn==2.8.1 --no-deps --no-cache-dir + +# set entrypoint to activate the virtual environment +ENTRYPOINT ["/bin/bash", "-c", "source /opt/venv/bin/activate && exec \"$@\"", "--"] +CMD ["bash"] diff --git a/scripts/download_dataset.py b/scripts/download_dataset.py index 052788ef..565f338f 100644 --- a/scripts/download_dataset.py +++ b/scripts/download_dataset.py @@ -1,40 +1,56 @@ -import os; os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' # 必须放在第一行 import argparse import glob +import os import time + +from beast_logger import print_list from huggingface_hub import snapshot_download -parser = argparse.ArgumentParser(description='download Hugging Face dataset') -parser.add_argument('--target', default='openai/gsm8k', type=str, help='要下载的数据集仓库名称') -parser.add_argument('--path', default='/mnt/data/qingxu.fu/dataset/openai/gsm8k', type=str, help='路径到下载的本地目录') +parser = argparse.ArgumentParser(description="download Hugging Face dataset") +parser.add_argument("--target", default="openai/gsm8k", type=str, help="HuggingFace dataset name") +parser.add_argument( + "--path", + default="./dataset/openai/gsm8k", + type=str, + help="Path to the local directory where the dataset will be downloaded", +) args = parser.parse_args() -snapshot_download(repo_id=args.target, repo_type="dataset", local_dir=args.path, resume_download=True) +snapshot_download( + repo_id=args.target, + repo_type="dataset", + local_dir=args.path, + resume_download=True, +) time.sleep(2) -from beast_logger import print_list + downloaded = [] -for item in glob.glob(os.path.join(args.path, '**', '*')): +for item in glob.glob(os.path.join(args.path, "**", "*")): downloaded += [os.path.abspath(item)] -print_list(downloaded, header='downloaded files') +print_list(downloaded, header="downloaded files") + def display_dataset(dataset_name, dataset_iter, header): from beast_logger import print_listofdict + data = [] for sample in dataset_iter: s = dict(sample) data.append(s) print_listofdict(data[:5], header=header) + try: import datasets - dataset_iter = datasets.load_dataset(args.path, name='main', split='train') - display_dataset(args.target, dataset_iter, header='train') - dataset_iter = datasets.load_dataset(args.path, name='main', split='test') - display_dataset(args.target, dataset_iter, header='test') + + dataset_iter = datasets.load_dataset(args.path, name="main", split="train") + display_dataset(args.target, dataset_iter, header="train") + dataset_iter = datasets.load_dataset(args.path, name="main", split="test") + display_dataset(args.target, dataset_iter, header="test") except Exception as e: print(f"Error loading dataset {args.target}: {e}") -# python -m scripts.download_dataset --path='/root/data/gsm8k' --target='openai/gsm8k' \ No newline at end of file +# python -m scripts.download_dataset --path='./dataset/openai/gsm8k' --target='openai/gsm8k' diff --git a/scripts/download_model.py b/scripts/download_model.py index 30dcdec7..45ce6008 100644 --- a/scripts/download_model.py +++ b/scripts/download_model.py @@ -1,28 +1,26 @@ - ms = input("modelscope ? (Y/n)") -if ms == "Y" or ms =="y": - - from modelscope import snapshot_download +if ms == "Y" or ms == "y": from loguru import logger - cache_dir = input("model path ( /mnt/data/model_cache/modelscope/hub/Qwen ): ").strip() + from modelscope import snapshot_download + + cache_dir = input("model path (./modelscope_cache): ").strip() if not cache_dir: - cache_dir = '/mnt/data/model_cache/modelscope/hub/Qwen' + cache_dir = "./modelscope_cache" res = snapshot_download(input("model name: ").strip(), cache_dir=cache_dir) logger.success(res) else: - import os import subprocess - os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + repo_name = input("model name: ").strip() - command = ['huggingface-cli', 'download', '--resume-download', repo_name] + command = ["huggingface-cli", "download", "--resume-download", repo_name] process = subprocess.run(command, env=os.environ, check=True) if process.returncode == 0: - print(f"成功下载 {repo_name}") + print(f"Download {repo_name} succeeded") else: - print(f"下载 {repo_name} 失败") + print(f"Download {repo_name} failed") # python -m scripts.download_model -# Qwen/Qwen3-0.6B \ No newline at end of file +# Qwen/Qwen3-0.6B diff --git a/scripts/install_vllm_sglang_mcore.sh b/scripts/install_vllm_sglang_mcore.sh deleted file mode 100644 index e8064769..00000000 --- a/scripts/install_vllm_sglang_mcore.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash - -USE_MEGATRON=${USE_MEGATRON:-1} -USE_SGLANG=${USE_SGLANG:-1} - -export MAX_JOBS=32 - -echo "1. install inference frameworks and pytorch they need" -if [ $USE_SGLANG -eq 1 ]; then - pip install "sglang[all]==0.4.6.post1" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir -fi -pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata - -echo "2. install basic packages" -pip install "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest py-spy pyext pre-commit ruff - -pip install "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - - -echo "3. install FlashAttention and FlashInfer" -# Install flash-attn-2.7.4.post1 (cxx11abi=False) -wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False) -# vllm-0.8.3 does not support flashinfer>=0.2.3 -# see https://github.com/vllm-project/vllm/pull/15777 -wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl - - -if [ $USE_MEGATRON -eq 1 ]; then - echo "4. install TransformerEngine and Megatron" - echo "Notice that TransformerEngine installation can take very long time, please be patient" - NVTE_FRAMEWORK=pytorch pip3 install --no-deps git+https://github.com/NVIDIA/TransformerEngine.git@v2.2 - pip3 install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.0rc3 -fi - - -echo "5. May need to fix opencv" -pip install opencv-python -pip install opencv-fixer && \ - python -c "from opencv_fixer import AutoFix; AutoFix()" - - -if [ $USE_MEGATRON -eq 1 ]; then - echo "6. Install cudnn python package (avoid being overridden)" - pip install nvidia-cudnn-cu12==9.8.0.87 -fi - -echo "Successfully installed all packages" diff --git a/scripts/model_merger.py b/scripts/model_merger.py deleted file mode 100644 index 3bd25cae..00000000 --- a/scripts/model_merger.py +++ /dev/null @@ -1,623 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. - -To merge FSDP checkpoints: -```sh -python scripts/model_merger.py merge \ - --backend fsdp \ - --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -To merge Megatron checkpoints: -```sh -python scripts/model_merger.py merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -For more details, please refer to documentation: -https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model -""" - -import argparse -import os -import re -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional - -import numpy as np -import torch -from accelerate import init_empty_weights -from safetensors.torch import load_file -from torch.distributed._tensor import Placement, Shard -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForTokenClassification, - AutoModelForVision2Seq, - GenerationConfig, - PretrainedConfig, -) - -try: - # for torch 2.5+ - from torch.distributed.tensor import DTensor -except ImportError: - from torch.distributed._tensor import DTensor - -from tqdm import tqdm - -from verl.utils import hf_processor, hf_tokenizer - - -@dataclass -class ModelMergerConfig: - operation: str # 'merge' or 'test' - backend: str - local_dir: str - hf_model_config_path: str - target_dir: Optional[str] = "tmp" - hf_upload_path: Optional[str] = None - private: bool = False - test_hf_dir: Optional[str] = None - tie_word_embedding: bool = False - is_value_model: bool = False - hf_model_path: Optional[str] = None - hf_upload: bool = field(init=False) - - def __post_init__(self): - self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) - if self.operation == "test": - self.target_dir = None - self.hf_upload_path = None - self.private = False - - -class BaseModelMerger(ABC): - def __init__(self, config: ModelMergerConfig): - self.config = config - self.hf_model_config_path = config.hf_model_config_path - - if config.hf_model_path: - print("Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. ") - self.hf_model_config_path = config.hf_model_path - - self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) - - def get_transformers_auto_model_class(self): - if "ForTokenClassification" in self.model_config.architectures[0]: - return AutoModelForTokenClassification - elif "ForCausalLM" in self.model_config.architectures[0]: - return AutoModelForCausalLM - elif "ForConditionalGeneration" in self.model_config.architectures[0]: - return AutoModelForVision2Seq - - raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") - - def patch_model_generation_config(self, model): - """ - The generation_config created from model config may be different to the pretrained model, - this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 - - This function patch the generation_config created from model config to the pretrained model. - """ - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) - except OSError: - print(f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config.") - return model - - def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - with init_empty_weights(): - model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) - model.to_empty(device="cpu") - model = self.patch_model_generation_config(model) - - print(f"Saving model to {self.config.target_dir}") - model.save_pretrained(self.config.target_dir, state_dict=state_dict) - del state_dict - del model - - processor = hf_processor(self.hf_model_config_path) - tokenizer = hf_tokenizer(self.hf_model_config_path) - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def upload_to_huggingface(self): - from huggingface_hub import HfApi - - api = HfApi() - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") - - @abstractmethod - def merge_and_save(self): - raise NotImplementedError("Subclasses should implement this method") - - -class FSDPModelMerger(BaseModelMerger): - def _get_world_size(self) -> int: - """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').""" - for filename in os.listdir(self.config.local_dir): - match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) - if match: - return int(match.group(1)) - raise FileNotFoundError(f"Could not determine world size. No file matching 'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}") - - def _load_rank_zero_state_dict(self, world_size: int) -> dict: - return torch.load(Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", map_location="cpu", weights_only=False) - - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: - """ - Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. - If no DTensor is found, infers a simple FSDP mesh based on world_size. - """ - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - - if isinstance(weight, DTensor): - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names - else: - # for non-DTensor - mesh = np.array([world_size], dtype=np.int64) - mesh_dim_names = ("fsdp",) - - return mesh, mesh_dim_names - - def _calculate_shard_configuration(self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]) -> tuple[int, tuple[int, ...]]: - """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" - - if "tp" in mesh_dim_names: - # TODO: "tp" is not supported yet due to the above assert - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) - - return total_shards, mesh_shape - - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: - """Merges a list of tensors based on their DTensor placement""" - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - - raise NotImplementedError(f"Unsupported placement: {placement}") - - def _load_and_merge_state_dicts(self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]) -> dict[str, torch.Tensor]: - model_state_dict_lst = [None] * total_shards - - def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict - return state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): - future.result() - - # Merge state dicts from all shards - state_dict = {} - param_placements: dict[str, list] = {} - - for key in set(model_state_dict_lst[0].keys()): - state_dict[key] = [] - for model_state_shard in model_state_dict_lst: - # add tensor shard in order of rank to state_dict[key] - tensor = model_state_shard.pop(key) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] in ("dp", "ddp"): - placements = placements[1:] - - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key].append(tensor.bfloat16()) - - del model_state_dict_lst - - # Merge tensors - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - print(f"No need to merge key {key}") - continue - if key in param_placements: - # merge shards - placements: tuple[Shard] = param_placements[key] - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = self._merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") - else: - state_dict[key] = torch.cat(state_dict[key], dim=0) - - return state_dict - - def merge_and_save(self): - world_size = self._get_world_size() - rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) - print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) - print(f"Processing model shards with {total_shards} {mesh_shape} in total") - - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) - hf_state_dict = hf_model.state_dict() - del hf_model - - hf_model_keys = set(hf_state_dict.keys()) - collected_keys = set(state_dict.keys()) - - missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - - extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - - for key in hf_model_keys: - hf_shape = hf_state_dict[key].shape - collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - - hf_dtype = hf_state_dict[key].dtype - collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) - - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") - - -class MegatronModelMerger(BaseModelMerger): - def __init__(self, config: ModelMergerConfig): - from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path - - config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) - super().__init__(config) - - def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: - match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir) - assert match, f"Invalid sharded dir {sharded_dir}" - tp_rank = int(match.group(1)) - pp_rank = int(match.group(2)) - return tp_rank, pp_rank - - def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: - """ - Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). - Determines TP and PP sizes from directory names. - """ - tp_size = 0 - pp_size = 0 - sharded_dirs = sorted(os.listdir(model_path)) - for sharded_dir in sharded_dirs: - assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - tp_size = max(tp_size, tp_rank + 1) - pp_size = max(pp_size, pp_rank + 1) - return sharded_dirs, tp_size, pp_size - - def _merge_across_tp(self, key: str, tp_data: list[torch.Tensor], config: PretrainedConfig, tp_size: int, is_value_model: bool = False) -> torch.Tensor | list[torch.Tensor]: - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in tp_data: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - return [gate, up] - - elif "self_attention.linear_qkv." in key and "layer_norm" not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst = [] - k_lst = [] - v_lst = [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 - kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - - for infer_param in tp_data: - num_query_groups_per_partition = config.num_key_value_heads // tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - return [q, k, v] - - elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and is_value_model: - return tp_data[0] - else: - dim = 0 - if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: - dim = 1 - return torch.cat(tp_data, dim=dim) - - def _load_state_dicts(self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int) -> list[list[dict]]: - model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)] - - def _process_one_megatron_shard(sharded_dir: str): - model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" - state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - model_state_dict_lst[pp_rank][tp_rank] = state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] - for future in tqdm(futures, desc=f"Loading {len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): - future.result() - - return model_state_dict_lst - - def _merge_state_dicts(self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int) -> dict[str, torch.Tensor]: - state_dict = {} - vpp_size = len(model_state_dict_lst[0][0]) - layers_cum = 0 - - for vpp_rank in range(vpp_size): - for pp_rank in range(pp_size): - layers_handled = 0 - keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() - for key in keys: - if "extra_state" in key: - continue - if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - - new_key = key - if "decoder.layers." in key: - local_layer_no = int(key.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = key.split(".") - new_key_list[2] = str(global_layer_no) - new_key = ".".join(new_key_list) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = self._merge_across_tp(new_key, tp_data, self.model_config, tp_size, self.config.is_value_model) - - if not isinstance(merged, list): - state_dict[new_key] = merged - elif len(merged) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], merged): - state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d - elif len(merged) == 2: - # split gate up - state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0] - state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1] - - layers_cum += layers_handled + 1 # zero based - - return state_dict - - def merge_and_save(self): - from verl.utils.megatron_utils import get_model_checkpoint_path - - model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) - sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) - print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") - - model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) - merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) - del model_state_dict_lst - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - """ - Compares the merged Megatron state_dict against a reference safetensors model. - Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. - """ - ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") - - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ("self_attention.linear_q", "self_attn.q_proj"), - ("self_attention.linear_k", "self_attn.k_proj"), - ("self_attention.linear_v", "self_attn.v_proj"), - ] - - for original_name, loaded_weight in state_dict.items(): - name = self._replace_name(original_name, params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embedding and "lm_head.weight" in name: - continue - if name not in ref_state_dict: - raise RuntimeError(f"key: {name} not exist in state_dict") - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2) - - def _replace_name(self, megatron_name: str, name_mapping: list[tuple[str, str]]) -> str: - for m_name, v_name in name_mapping: - if m_name not in megatron_name: - continue - if "layers" in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace("decoder", "model") - megatron_name_list = megatron_name.split(".") - if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: - param_name_list = megatron_name_list[:3] - param_name_list.append(v_name) - param_name = ".".join(param_name_list) - else: - param_name_list = megatron_name_list[:3] - weight_or_bias = megatron_name_list[-1] - param_name_list.append(v_name) - param_name_list.append(weight_or_bias) - param_name = ".".join(param_name_list) - return param_name - else: - param_name = megatron_name.replace(m_name, v_name) - return param_name - return None # Return None if no mapping found - - -def main(): - parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") - - base_op_parser = argparse.ArgumentParser(add_help=False) - base_op_parser.add_argument("--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model") - base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") - base_op_parser.add_argument("--hf_model_path", type=str, default=None, help="(Deprecated) Path to the original Hugging Face model for config.") - base_op_parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights (currently only Megatron supported)") - base_op_parser.add_argument("--is-value-model", action="store_true", help="Whether the model is a value model (currently only Megatron supported)") - - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") - merge_parser.add_argument("--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model") - merge_parser.add_argument("--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model") - merge_parser.add_argument("--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository") - - test_parser = subparsers.add_parser("test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model") - test_parser.add_argument("--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing") - - args = parser.parse_args() - - common_config_args = { - "operation": args.operation, - "backend": args.backend, - "tie_word_embedding": args.tie_word_embedding, - "is_value_model": args.is_value_model, - "local_dir": args.local_dir, - "hf_model_path": args.hf_model_path, - "hf_model_config_path": args.local_dir, - } - - if args.operation == "merge": - config = ModelMergerConfig( - **common_config_args, - target_dir=args.target_dir, - hf_upload_path=args.hf_upload_path, - private=args.private, - test_hf_dir=None, - ) - os.makedirs(config.target_dir, exist_ok=True) - elif args.operation == "test": - config = ModelMergerConfig( - **common_config_args, - test_hf_dir=args.test_hf_dir, - # the following args are not used by test operation - target_dir=None, - hf_upload_path=None, - private=False, - ) - else: - raise NotImplementedError(f"Unknown operation: {args.operation}") - - if config.backend == "fsdp": - merger = FSDPModelMerger(config) - elif config.backend == "megatron": - merger = MegatronModelMerger(config) - else: - raise NotImplementedError(f"Unknown backend: {config.backend}") - - merger.merge_and_save() - - -if __name__ == "__main__": - main() diff --git a/scripts/tar_and_copy.sh b/scripts/tar_and_copy.sh deleted file mode 100755 index d898b57e..00000000 --- a/scripts/tar_and_copy.sh +++ /dev/null @@ -1,4 +0,0 @@ -cd ../ -gtar -cvzf BeyondAgent.tar.gz BeyondAgent -#scp BeyondAgent.tar.gz jinli.yl@11.160.132.45:/home/jinli.yl/jinli_mnt2/workspace/ -scp -P 1016 BeyondAgent.tar.gz root@8.130.105.202:/mnt/data/jinli.yl/ \ No newline at end of file diff --git a/scripts/test_dashscope_api.py b/scripts/test_dashscope_api.py new file mode 100644 index 00000000..dcf26c2b --- /dev/null +++ b/scripts/test_dashscope_api.py @@ -0,0 +1,25 @@ +import asyncio +from ajet.utils.robust_dashscope import RobustDashScopeChatModel + + +async def test_dashscope_api(): + """Test the RobustDashScopeChatModel by making a simple API call.""" + try: + llm = RobustDashScopeChatModel("qwen-plus", stream=False) + + # Sample messages for a basic conversation + messages = [{"role": "user", "content": "Hello! Can you tell me a short joke?"}] + + # Call the model + response = await llm(messages) + + # Print and verify the response + print(response) + + except Exception as e: + print(f"Test failed with error: {e}") + raise + + +if __name__ == "__main__": + asyncio.run(test_dashscope_api()) diff --git a/scripts/untar.sh b/scripts/untar.sh deleted file mode 100644 index 4007e209..00000000 --- a/scripts/untar.sh +++ /dev/null @@ -1 +0,0 @@ -rm -rf BeyondAgent && tar -xzvf BeyondAgent.tar.gz \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 58f29f5a..00000000 --- a/setup.py +++ /dev/null @@ -1,7 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name="astune", # Change this to your package name - version="0.1.0", - packages=find_packages(), -) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bench/README.md b/tests/bench/README.md new file mode 100644 index 00000000..a849d3ca --- /dev/null +++ b/tests/bench/README.md @@ -0,0 +1,27 @@ +Note: `tests/bench` source code is for test robot only, therefore `yaml` configurations will contain dataset files stored in benchmarking-docker-image. + +- To get these dataset files, please refer to `tutorial/*`. + +- Benchmarking-docker-image for test-robot will be released in 2026 Feb. + +## Cheat Sheet + +```python +# prepare model path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen +# prepare dataset path +# prepare swanlab api + +source .venv/bin/activate + +python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py +python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py +python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py +python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py +python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py + +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py::TestBenchmarkAppworld::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py::TestBenchmarkCountdown::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py::TestBenchmarkFrozenLake::test_01_begin_verl +``` diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.py b/tests/bench/benchmark_appworld/benchmark_appworld.py new file mode 100644 index 00000000..70b440bf --- /dev/null +++ b/tests/bench/benchmark_appworld/benchmark_appworld.py @@ -0,0 +1,29 @@ +# flake8: noqa +import time + +from ajet.utils.testing_utils import BenchmarkProbe, singleton + + +@singleton +class TestProbe(BenchmarkProbe): + def __init__(self): + # fmt: off + self.expected_train_time = 3600 * 48 # 24 hours + self.begin_time = time.time() + self.reward_array = [] + self.reward_expectation_avg_window = 5 + self.reward_expectation = { + # step : expected local average reward range + # step : [low, high ] + 5 : [0.0, 99999.0], + 10 : [0.0, 99999.0], + 20 : [0.0, 99999.0], + 30 : [0.0, 99999.0], + } + # fmt: on + self.probe_list = ["reward_probe"] + self.reward_key = "reward_for_test_robot" + self.probe_key = "reward_probe" + + def __call__(self, key, log_dict): + return super().__call__(key, log_dict) diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.yaml b/tests/bench/benchmark_appworld/benchmark_appworld.yaml new file mode 100644 index 00000000..f83e91f0 --- /dev/null +++ b/tests/bench/benchmark_appworld/benchmark_appworld.yaml @@ -0,0 +1,73 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_appworld + experiment_name: "read_yaml_name" + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + task_reader: + type: env_service # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # ✨✨✨✨ 编写并选择Agent + force_disable_toolcalls: True + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 25 + max_steps: 25 + n_vllm_engine: 2 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 1 + n_gpus_per_node: 8 + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml new file mode 100644 index 00000000..4ae12f17 --- /dev/null +++ b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml @@ -0,0 +1,78 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_appworld + experiment_name: "read_yaml_name" + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + task_reader: + type: env_service # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # ✨✨✨✨ 编写并选择Agent + force_disable_toolcalls: True + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 4 + max_num_seqs: 10 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 30 + max_steps: 30 + n_vllm_engine: 2 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 2 + n_gpus_per_node: 8 + ulysses_sequence_parallel_size: 2 + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # + +trinity: + synchronizer: + sync_offset: 0 + sync_method: nccl + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml new file mode 100644 index 00000000..e3175d19 --- /dev/null +++ b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml @@ -0,0 +1,71 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_appworld + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + task_reader: + type: env_service # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # ✨✨✨✨ 编写并选择Agent + force_disable_toolcalls: True + user_workflow: tutorial.example_appworld.appworld_oai_sdk->ExampleAgentScopeWorkflow + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 25 + max_steps: 25 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 1 + n_gpus_per_node: 8 + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_appworld/execute_benchmark_appworld.py b/tests/bench/benchmark_appworld/execute_benchmark_appworld.py new file mode 100644 index 00000000..77b3dc8b --- /dev/null +++ b/tests/bench/benchmark_appworld/execute_benchmark_appworld.py @@ -0,0 +1,100 @@ +import os +import subprocess +import time + +from loguru import logger + +from tests.bench.benchmark_base import BenchmarkTestCase + + +class TestBenchmarkAppworld(BenchmarkTestCase): + def test_01_begin_verl(self): + # get probe target, so as to get timeout settings + BACKBONE = "verl" + TEST_TARGET = "tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml" + PROBE_TARGET = "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" + # tests/bench/benchmark_appworld/benchmark_appworld.py + # tests/bench/benchmark_appworld/benchmark_appworld.yaml + TARGET_NAME = f"benchmark_appworld_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + multi_nodes = False + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + extra_cmd_args=["--with-appworld"], + pre_launch=self.install_appworld, + use_ray_cluster=multi_nodes, + enable_ray_for_trinity=not multi_nodes, + ) + + def test_02_begin_trinity(self): + # get probe target, so as to get timeout settings + BACKBONE = "trinity" + TEST_TARGET = "tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml" + PROBE_TARGET = "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" + TARGET_NAME = f"benchmark_appworld_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + multi_nodes = True + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + extra_cmd_args=["--with-appworld"], + pre_launch=self.install_appworld, + use_ray_cluster=multi_nodes, + enable_ray_for_trinity=not multi_nodes, + ) + + def clear_system_processes(self): + # kill all python + ray + vllm processes + from ajet.utils.cleaner import fast_kill_by_keyword_bash + + total_seconds = 15 + for i in range(total_seconds): + logger.warning(f"Warning: To install Appworld, we have kill all `python / VLLM / vllm / ray` processes in your system. IF this is NOT acceptable, TERMINATE NOW! Execute in {total_seconds - i} seconds...") + time.sleep(1) + + kill = "ray|vllm|VLLM|python" + for keyword in kill.split("|"): + logger.info(f"Killing processes matching keyword: {keyword}") + killed_pids = fast_kill_by_keyword_bash(keyword) + if killed_pids: + logger.success(f"Successfully killed processes with PIDs: {killed_pids}") + else: + logger.warning(f"No processes found matching keyword: {keyword}") + + def install_appworld(self): + # run: + # `rm -rf /tmp/pack_all_in_one & wget https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v2.tar.gz && tar -xzf ./appworld_pack_v2.tar.gz -C /tmp` + self.clear_system_processes() + import shutil + + if os.path.exists("/tmp/pack_all_in_one"): + shutil.rmtree("/tmp/pack_all_in_one") + if os.path.exists("./appworld_pack_v2.tar.gz"): + os.remove("./appworld_pack_v2.tar.gz") + subprocess.run( + [ + "wget", + "https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/appworld_pack_v2.tar.gz", + ] + ) + subprocess.run( + [ + "tar", + "-xzf", + "./appworld_pack_v2.tar.gz", + "-C", + "/tmp", + ] + ) + # write + os.environ["APPWORLD_PATH"] = "/tmp/pack_all_in_one" + os.environ["APPWORLD_SCRIPT"] = "bash EnvService/env_sandbox/appworld.sh" diff --git a/tests/bench/benchmark_base.py b/tests/bench/benchmark_base.py new file mode 100644 index 00000000..7b6815f3 --- /dev/null +++ b/tests/bench/benchmark_base.py @@ -0,0 +1,108 @@ +import os +import unittest +from pathlib import Path +from typing import Callable, List, Optional + +from beast_logger import print_dict + +from ajet.utils.dynamic_import import dynamic_import +from ajet.utils.smart_daemon import LaunchCommandWhenAbsent +from ajet.utils.testing_utils import ( + populate_test_env_metadata, + send_test_result, +) + + +class BenchmarkTestCase(unittest.TestCase): + def execute_benchmark( + self, + *, + backbone: str, + test_target: str, + probe_target: str, + target_name: str, + python_executable: str, + extra_cmd_args: Optional[List[str]] = None, + pre_launch: Optional[Callable[[], None]] = None, + use_ray_cluster: bool = False, + enable_ray_for_trinity: bool = True, + ) -> None: + """Run a benchmark with shared boilerplate for setup and process management.""" + workspace_dir = Path(__file__).resolve().parents[2] + + git_hash, req_txt = populate_test_env_metadata(str(workspace_dir)) + os.environ["AJET_GIT_HASH"] = git_hash + os.environ["AJET_REQ_TXT"] = req_txt + os.environ["AJET_BENCHMARK_NAME"] = target_name + + if pre_launch: + pre_launch() + + send_test_result( + git_hash=git_hash, + target=target_name, + status="running", + status_detail="", + req_txt=req_txt, + append_log="", + data_dashboard_url="", + timeout=10.0, + ) + + timeout_seconds = dynamic_import(probe_target)().expected_train_time + 600 + + cmd = [ + python_executable, + "-m", + "ajet.launcher", + "--conf", + test_target, + "--backbone", + backbone, + "--autokill", + ] + if extra_cmd_args: + cmd += extra_cmd_args + if use_ray_cluster: + cmd += ["--with-ray-cluster"] + elif enable_ray_for_trinity and backbone == "trinity": + cmd += ["--with-ray"] + + companion = LaunchCommandWhenAbsent( + full_argument_list=cmd, + dir=str(workspace_dir), + tag=target_name, + ) + + test_successful = False + terminate_str = companion.launch( + launch_wait_time=timeout_seconds, + success_std_string=[ + "TestSuccessException", + "TestFailException", + "You can force stop the `Trainer` process by pressing Ctrl+C", + "torch.OutOfMemoryError: CUDA out of memory", + ], + env_dict=os.environ, + force_restart=True, + ) + test_successful = True + companion.kill_self() + if terminate_str == "TestSuccessException": + test_successful = True + elif terminate_str == "TestFailException": + test_successful = False + raise RuntimeError("Benchmark test failed during execution.") + elif terminate_str == "You can force stop the `Trainer` process by pressing Ctrl+C": + test_successful = False + raise RuntimeError("Unknown trinity exception.") + else: + test_successful = False + raise RuntimeError(f"Benchmark test timed out or crashed. {test_successful}") + + print_dict( + { + "TestTarget": test_target, + "TestSuccessful": test_successful, + } + ) diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.py b/tests/bench/benchmark_countdown/benchmark_countdown.py new file mode 100644 index 00000000..fedb48f7 --- /dev/null +++ b/tests/bench/benchmark_countdown/benchmark_countdown.py @@ -0,0 +1,30 @@ +# flake8: noqa +import time + +from ajet.utils.testing_utils import BenchmarkProbe, singleton + + +@singleton +class TestProbe(BenchmarkProbe): + def __init__(self): + # fmt: off + self.expected_train_time = 3600 * 12 # 12 hours budget for countdown benchmark + self.begin_time = time.time() + self.reward_array = [] + self.reward_expectation_avg_window = 30 + self.reward_expectation = { + # step : expected local average reward range + # step : [low, high ] + 30 : [0.30, 99999.0], + 60 : [0.40, 99999.0], + 90 : [0.45, 99999.0], + 120 : [0.50, 99999.0], + 150 : [0.55, 99999.0], + } + # fmt: on + self.probe_list = ["reward_probe"] + self.reward_key = "reward_for_test_robot" + self.probe_key = "reward_probe" + + def __call__(self, key, log_dict): + return super().__call__(key, log_dict) diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.yaml b/tests/bench/benchmark_countdown/benchmark_countdown.yaml new file mode 100644 index 00000000..fcd07f35 --- /dev/null +++ b/tests/bench/benchmark_countdown/benchmark_countdown.yaml @@ -0,0 +1,139 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: benchmarking + + model: + # ✨✨✨✨ which model should be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct + + data: + # max number of tokens for prompt + max_prompt_length: 1024 + # max number of tokens for response + max_response_length: 4096 + # how many tasks per training batch + train_batch_size: 32 + # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) + + + rollout: + + # ✨✨✨✨ the path to the workflow class + user_workflow: tutorial.example_countdown.countdown->ExampleCountdownLearn + + # whether or not to disable all tool calls + force_disable_toolcalls: True + + # maximum number of parallel environments / simulate workers + max_env_worker: 128 + + # step reward gamma (experimental, do not change) + gamma: 1.0 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + # send signal to terminate context tracing when LLM is losing control + agent_madness_termination: True # terminate_after_gone_mad + # punish the LLM when it is detected as lost control + agent_madness_reward: -1.0 + + # max response length in one turn + max_response_length_in_one_turn: 4096 + + # max token length allowed for the model during rollout + max_model_len: 5120 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 30 + # limit the maximum steps for each task + max_steps: 30 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + # TP size for rollout engine + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 2 + + # how many sequences are allowed to be processed in parallel by each vllm engine + max_num_seqs: 10 + + # the usage of infer engine, options: (vllm, sglang) + name: vllm + + # how many times a task should be repeated + num_repeat: 4 + + # rollout kwargs + temperature: 0.9 + top_p: 1.0 + + # validation kwargs + val_kwargs: + temperature: 0.0 + top_k: -1 + top_p: 1.0 + do_sample: False + num_repeat: 1 + + + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` or `data_generation` + huggingface_dat_repo: + dataset_path: "/mnt/data_cpfs/model_cache/modelscope/dataset/Countdown-Tasks" + training_split: "train" + validation_split: "test" + + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + # ✨✨✨✨ when `judge_type == customized_protocol` + judge_protocol: tutorial.example_countdown.countdown_answer_as_judge->CountdownAnswerAsJudge + + + # trainer common configurations + trainer_common: + val_before_train: False + val_pass_n: 4 + save_freq: 50 + test_freq: 20 + total_epochs: 5 + nnodes: 1 + n_gpus_per_node: 8 + logger: swanlab + algorithm: + adv_estimator: grpo + use_kl_in_reward: False + mini_batch_num: 1 + fsdp_config: + param_offload: True + optimizer_offload: True + optim: + lr: 1e-6 + use_kl_loss: True + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + ulysses_sequence_parallel_size: 1 + + + # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_test: True # FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_testing_lambda: "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" # FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_countdown/execute_benchmark_countdown.py b/tests/bench/benchmark_countdown/execute_benchmark_countdown.py new file mode 100644 index 00000000..59511216 --- /dev/null +++ b/tests/bench/benchmark_countdown/execute_benchmark_countdown.py @@ -0,0 +1,36 @@ +import os +import unittest + +from tests.bench.benchmark_base import BenchmarkTestCase + + +class TestBenchmarkCountdown(BenchmarkTestCase, unittest.TestCase): + def test_01_begin_verl(self): + BACKBONE = "verl" + TEST_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.yaml" + PROBE_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" + TARGET_NAME = f"benchmark_countdown_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) + + def test_02_begin_trinity(self): + BACKBONE = "trinity" + TEST_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.yaml" + PROBE_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" + TARGET_NAME = f"benchmark_countdown_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) diff --git a/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py new file mode 100644 index 00000000..7eadcf41 --- /dev/null +++ b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.py @@ -0,0 +1,29 @@ +# flake8: noqa +import time + +from ajet.utils.testing_utils import BenchmarkProbe, singleton + + +@singleton +class TestProbe(BenchmarkProbe): + def __init__(self): + # fmt: off + self.expected_train_time = 3600 * 12 # 12 hours budget for frozenlake easy benchmark + self.begin_time = time.time() + self.reward_array = [] + self.reward_expectation_avg_window = 20 + self.reward_expectation = { + # step : expected local average reward range + # step : [low, high ] + 50 : [0.0, 99999.0], + 100 : [0.0, 99999.0], + 150 : [0.0, 99999.0], + 200 : [0.0, 99999.0], + } + # fmt: on + self.probe_list = ["reward_probe"] + self.reward_key = "reward_for_test_robot" + self.probe_key = "reward_probe" + + def __call__(self, key, log_dict): + return super().__call__(key, log_dict) diff --git a/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml new file mode 100644 index 00000000..1e08d03c --- /dev/null +++ b/tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml @@ -0,0 +1,96 @@ +# ------------------ main config ------------------ +ajet: + project_name: benchmarking + + task_reader: + type: random_dummy # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` or `random_dummy` + + task_judge: + # ✨✨✨✨ select evaluation function + judge_protocol: null + + model: + # ✨✨✨✨ select model to be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct + + + rollout: + # the path to the workflow class + user_workflow: tutorial.example_frozenlake.frozenlake->FrozenLakeWorkflow + force_disable_toolcalls: True + + temperature: 0.9 + + max_env_worker: 512 + + num_repeat: 4 + + agent_madness_reward: 0.0 + + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 4 + + mode: async + + max_num_seqs: 40 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: [] + + max_response_length_in_one_turn: 512 + + max_model_len: 22000 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 20 + # limit the maximum steps for each task + max_steps: 20 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 128 + max_prompt_length: 4000 + max_response_length: 18000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + total_training_steps: 25 + nnodes: 1 + n_gpus_per_node: 8 + logger: swanlab + + + execute_test: True + execute_testing_lambda: "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe" + + +frozen_lake: + frozen_lake_size: 4 + is_slippery: False + + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ do not edit ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py b/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py new file mode 100644 index 00000000..bed5032e --- /dev/null +++ b/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py @@ -0,0 +1,35 @@ +import os + +from tests.bench.benchmark_base import BenchmarkTestCase + + +class TestBenchmarkFrozenLake(BenchmarkTestCase): + def test_01_begin_verl(self): + BACKBONE = "verl" + TEST_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml" + PROBE_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe" + TARGET_NAME = f"benchmark_frozenlake_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) + + def test_02_begin_trinity(self): + BACKBONE = "trinity" + TEST_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml" + PROBE_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe" + TARGET_NAME = f"benchmark_frozenlake_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py new file mode 100644 index 00000000..d3d18e4c --- /dev/null +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.py @@ -0,0 +1,38 @@ +# flake8: noqa +import time + +from ajet.utils.testing_utils import BenchmarkProbe, singleton + +# trinity b.b. expectation +# [TestProbe] Step 50: local average reward over last self.reward_expectation_avg_window steps: 2.6618, expected range: [0.0, 99999.0] +# [TestProbe] Step 100: local average reward over last self.reward_expectation_avg_window steps: 2.8733, expected range: [0.0, 99999.0] +# [TestProbe] Step 200: local average reward over last self.reward_expectation_avg_window steps: 2.9725, expected range: [0.0, 99999.0] + +# verl b.b. expectation +# [TestProbe] Step 50: local average reward over last self.reward_expectation_avg_window steps: 3.1562, expected range: [0.0, 99999.0] +# [TestProbe] Step 100: local average reward over last self.reward_expectation_avg_window steps: 3.4732, expected range: [0.0, 99999.0] +# [TestProbe] Step 200: local average reward over last self.reward_expectation_avg_window steps: 3.5645, expected range: [0.0, 99999.0] + + +@singleton +class TestProbe(BenchmarkProbe): + def __init__(self): + # fmt: off + self.expected_train_time = 3600 * 24 # 24 hours + self.begin_time = time.time() + self.reward_array = [] + self.reward_expectation_avg_window = 20 + self.reward_expectation = { + # step : expected local average reward range + # step : [low, high ] + 50 : [2.5, 99999.0], + 100 : [2.7, 99999.0], + 200 : [2.9, 99999.0], + } + # fmt: on + self.probe_list = ["reward_probe"] + self.reward_key = "reward_for_test_robot" + self.probe_key = "reward_probe" + + def __call__(self, key, log_dict): + return super().__call__(key, log_dict) diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml new file mode 100644 index 00000000..dd3b6a18 --- /dev/null +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml @@ -0,0 +1,72 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_learn2ask_enhancedreward + task_reader: + type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: /mnt/data_cpfs/model_cache/modelscope/dataset/realmedconv/train.jsonl + validation: + file_path: /mnt/data_cpfs/model_cache/modelscope/dataset/realmedconv/test.jsonl + + model: + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + rollout: + user_workflow: "tutorial.example_learn2ask.learn2ask->ExampleLearn2Ask" + force_disable_toolcalls: True + temperature: 1.0 + max_env_worker: 64 + num_repeat: 6 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + multi_turn: + max_sample_per_task: 2 + + compute_madness_checklist: + - "nonsense" + agent_madness_reward: 0.0 + + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + n_vllm_engine: 2 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 100 + total_epochs: 100 + logger: swanlab + + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" # DO NOT EDIT, THIS IS FOR TEST ROBOT + + +trinity: + synchronizer: + sync_offset: 1 + sync_method: nccl + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py b/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py new file mode 100644 index 00000000..39059b1d --- /dev/null +++ b/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py @@ -0,0 +1,37 @@ +import os + +from tests.bench.benchmark_base import BenchmarkTestCase + + +class TestBenchmarkLearnToAsk(BenchmarkTestCase): + def test_01_begin_verl(self): + # get probe target, so as to get timeout settings + BACKBONE = "verl" + TEST_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml" + PROBE_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" + TARGET_NAME = f"benchmark_learn2ask_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) + + def test_02_begin_trinity(self): + # get probe target, so as to get timeout settings + BACKBONE = "trinity" + TEST_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml" + PROBE_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" + TARGET_NAME = f"benchmark_learn2ask_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) diff --git a/tests/bench/benchmark_math/benchmark_math.py b/tests/bench/benchmark_math/benchmark_math.py new file mode 100644 index 00000000..9d8397ca --- /dev/null +++ b/tests/bench/benchmark_math/benchmark_math.py @@ -0,0 +1,30 @@ +# flake8: noqa +import os +import time + +from ajet.utils.testing_utils import BenchmarkProbe, singleton + + +@singleton +class TestProbe(BenchmarkProbe): + def __init__(self): + # fmt: off + self.expected_train_time = 3600 * 24 # 24 hours + self.begin_time = time.time() + self.reward_array = [] + self.reward_expectation_avg_window = 5 + self.reward_expectation = { + # step : expected local average reward range + # step : [low, high ] + 5 : [0.10, 99999.0], + 10 : [0.45, 99999.0], + 20 : [0.68, 99999.0], + 30 : [0.85, 99999.0], + } + # fmt: on + self.probe_list = ["reward_probe"] + self.reward_key = "reward_for_test_robot" + self.probe_key = "reward_probe" + + def __call__(self, key, log_dict): + return super().__call__(key, log_dict) diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml new file mode 100644 index 00000000..0f24a0de --- /dev/null +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -0,0 +1,74 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: benchmarking + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: '/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main' + training_split: "train" + validation_split: "test" + + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + rollout: + user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent + temperature: 1.0 + max_env_worker: 64 + max_num_seqs: 256 + num_repeat: 6 + agent_madness_reward: 0.0 + tensor_model_parallel_size: 1 + multi_turn: + max_sample_per_task: 2 + compute_madness_checklist: + - "nonsense" + - "wrong_toolcall" + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + n_vllm_engine: 2 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 999999 + total_epochs: 100 + logger: swanlab + + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_math/benchmark_math.py->TestProbe" # + + +trinity: + synchronizer: + sync_offset: 1 + sync_method: nccl + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit + - trinity_default # trinity inherit + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml new file mode 100644 index 00000000..52795ba5 --- /dev/null +++ b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml @@ -0,0 +1,74 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: benchmarking + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: '/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main' + training_split: "train" + validation_split: "test" + + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + rollout: + user_workflow: "tutorial.example_math_agent.math_agent_oai_sdk->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent + temperature: 1.0 + max_env_worker: 64 + max_num_seqs: 256 + num_repeat: 6 + agent_madness_reward: 0.0 + tensor_model_parallel_size: 1 + multi_turn: + max_sample_per_task: 2 + compute_madness_checklist: + - "nonsense" + - "wrong_toolcall" + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + n_vllm_engine: 2 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 999999 + total_epochs: 100 + logger: swanlab + + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_math/benchmark_math.py->TestProbe" # + + +trinity: + synchronizer: + sync_offset: 1 + sync_method: nccl + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit + - trinity_default # trinity inherit + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml new file mode 100644 index 00000000..8d01ef21 --- /dev/null +++ b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml @@ -0,0 +1,74 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: benchmarking + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: '/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main' + training_split: "train" + validation_split: "test" + + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + rollout: + user_workflow: "tutorial.example_math_agent.math_agent_raw_http->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent + temperature: 1.0 + max_env_worker: 64 + max_num_seqs: 256 + num_repeat: 6 + agent_madness_reward: 0.0 + tensor_model_parallel_size: 1 + multi_turn: + max_sample_per_task: 2 + compute_madness_checklist: + - "nonsense" + - "wrong_toolcall" + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + n_vllm_engine: 2 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 999999 + total_epochs: 100 + logger: swanlab + + + execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT + execute_testing_lambda: "tests/bench/benchmark_math/benchmark_math.py->TestProbe" # + + +trinity: + synchronizer: + sync_offset: 1 + sync_method: nccl + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit + - trinity_default # trinity inherit + - ajet_default + - _self_ diff --git a/tests/bench/benchmark_math/execute_benchmark_math.py b/tests/bench/benchmark_math/execute_benchmark_math.py new file mode 100644 index 00000000..6c9fe9c5 --- /dev/null +++ b/tests/bench/benchmark_math/execute_benchmark_math.py @@ -0,0 +1,36 @@ +import os +from tests.bench.benchmark_base import BenchmarkTestCase + + +class TestBenchmarkMath(BenchmarkTestCase): + def test_01_begin_verl(self): + # get probe target, so as to get timeout settings + BACKBONE = "verl" + TEST_TARGET = "tests/bench/benchmark_math/benchmark_math.yaml" + PROBE_TARGET = "tests/bench/benchmark_math/benchmark_math.py->TestProbe" + TARGET_NAME = f"benchmark_math_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) + + def test_02_begin_trinity(self): + # get probe target, so as to get timeout settings + BACKBONE = "trinity" + TEST_TARGET = "tests/bench/benchmark_math/benchmark_math.yaml" + PROBE_TARGET = "tests/bench/benchmark_math/benchmark_math.py->TestProbe" + TARGET_NAME = f"benchmark_math_{BACKBONE}" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + + self.execute_benchmark( + backbone=BACKBONE, + test_target=TEST_TARGET, + probe_target=PROBE_TARGET, + target_name=TARGET_NAME, + python_executable=PYTHON_EXECUTABLE, + ) diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 00000000..f1fbca8e --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,88 @@ +import tempfile +import unittest + +import yaml + +from ajet.utils.config_utils import ( + align_parameters, + expand_ajet_hierarchical_config, + prepare_experiment_config, + read_ajet_hierarchical_config, +) + + +class TestConfigUtils(unittest.TestCase): + def test_load_config(self): + """A simple test to check if the configuration file is loaded without errors.""" + yaml_backup_dst, exp_base, exp_name, config = prepare_experiment_config("tests/data/config.yaml", "tests/temp", backbone="debug") + self.assertEqual(exp_name, "sample") + self.assertEqual(exp_base, "tests/temp/sample") + self.assertEqual(yaml_backup_dst, "tests/temp/sample/yaml_backup.yaml") + self.assertIn("ajet", config) + self.assertIn("project_name", config["ajet"]) + self.assertEqual(config["ajet"]["project_name"], "unittest") + self.assertIn("experiment_name", config["ajet"]) + self.assertEqual(config["ajet"]["experiment_name"], "sample") + self.assertIn("task_reader", config["ajet"]) + + def test_config_alignment_trinity(self): + """Test configuration alignment based on conversion JSON.""" + from_config_fp = "tests/data/config.yaml" + # Fixed config asset locations + TRINITY_CONFIG_AUTO_CONVERSION = "ajet/default_config/trinity/config_auto_convertion_trinity.jsonc" + + with tempfile.NamedTemporaryFile(mode="r", suffix=".yaml") as temp_yaml1: + config = read_ajet_hierarchical_config( + from_config_fp, + "dummy_exp_name", + backbone="trinity", + write_to=temp_yaml1.name, + exp_dir="tests/temp", + ) + expand_ajet_hierarchical_config(config, write_to=temp_yaml1.name) + align_parameters(temp_yaml1.name, temp_yaml1.name, TRINITY_CONFIG_AUTO_CONVERSION, "trinity") + with open(temp_yaml1.name, "r") as file: + to_config = yaml.safe_load(file) + self.assertEqual(to_config["checkpoint_root_dir"], "/wow/ajet_checkpoints") + self.assertEqual(to_config["buffer"]["batch_size"], 120) + self.assertEqual(to_config["buffer"]["train_batch_size"], 1920) + # Test simple field mappings + self.assertEqual(to_config["project"], "unittest") + self.assertEqual(to_config["name"], "dummy_exp_name") + self.assertEqual(to_config["model"]["model_path"], "") + # Test trainer common mappings + self.assertEqual(to_config["trainer"]["save_interval"], 99999) + self.assertEqual(to_config["buffer"]["total_epochs"], 99999) + self.assertEqual(to_config["explorer"]["eval_interval"], 99999) + # Test algorithm mappings + self.assertEqual(to_config["algorithm"]["repeat_times"], 8) + # Test explorer/rollout mappings + self.assertEqual(to_config["explorer"]["rollout_model"]["tensor_parallel_size"], 4) + # Test computed values + self.assertEqual(to_config["buffer"]["batch_size"], 120) + # (ajet.rollout.max_env_worker // ajet.rollout.n_vllm_engine) = 256 // 2 = 128 + self.assertEqual(to_config["explorer"]["runner_per_model"], 128) + + def test_config_alignment_verl(self): + """Test configuration alignment based on conversion JSON.""" + from_config_fp = "tests/data/config.yaml" + # Fixed config asset locations + TRINITY_CONFIG_AUTO_CONVERSION = "ajet/default_config/verl/config_auto_convertion_verl.jsonc" + + with tempfile.NamedTemporaryFile(mode="r", suffix=".yaml") as temp_yaml1: + config = read_ajet_hierarchical_config( + from_config_fp, + "dummy_exp_name", + backbone="verl", + write_to=temp_yaml1.name, + exp_dir="tests/temp", + ) + expand_ajet_hierarchical_config(config, write_to=temp_yaml1.name) + align_parameters(temp_yaml1.name, temp_yaml1.name, TRINITY_CONFIG_AUTO_CONVERSION, "trinity") + with open(temp_yaml1.name, "r") as file: + to_config = yaml.safe_load(file) + self.assertEqual(to_config["trainer"]["checkpoint_base_dir"], "/wow/ajet_checkpoints") + self.assertEqual( + to_config["trainer"]["default_local_dir"], + r"${checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name}", + ) diff --git a/tests/data/config.yaml b/tests/data/config.yaml new file mode 100644 index 00000000..24879957 --- /dev/null +++ b/tests/data/config.yaml @@ -0,0 +1,59 @@ +astuner: + project_name: unittest + experiment_name: sample + task_reader: + type: huggingface_dat_repo + huggingface_dat_repo: + dataset_path: "" + training_split: "train" + validation_split: "test" + + task_judge: + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + + model: + path: "" + + rollout: + agentscope_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" + temperature: 0.7 + max_env_worker: 256 + num_repeat: 8 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 4 + max_num_seqs: 256 + multi_turn: + max_sample_per_task: 4 + expected_steps: 2 + + + compute_madness_checklist: + - "nonsense" + - "wrong_toolcall" + data: + train_batch_size: 120 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + checkpoint_base_dir: "/wow/astuner_checkpoints" + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://astuner/default_config + - file://astuner/default_config/verl # verl only + - file://astuner/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - astune_default + - _self_ diff --git a/tests/data/example_cls.py b/tests/data/example_cls.py new file mode 100644 index 00000000..5676ce2e --- /dev/null +++ b/tests/data/example_cls.py @@ -0,0 +1,10 @@ +_this_is_a_test_dictionary = {} + + +class ExampleClass(object): + def __init__(self): + print("cls::" + str(id(_this_is_a_test_dictionary))) + self.value = id(_this_is_a_test_dictionary) + + def get_value(self): + return self.value diff --git a/tests/data_gen.py b/tests/data_gen.py new file mode 100644 index 00000000..0109e984 --- /dev/null +++ b/tests/data_gen.py @@ -0,0 +1,47 @@ +import random +import unittest + +import dotenv +from loguru import logger + +from ajet.data_generator.knowledge_augmentation import KnowledgeAugmentor +from ajet.data_generator.task_augmentation import TaskAugmentor +from ajet.task_reader import RouterTaskReader +from ajet.task_reader.document_reader.doc_reader import DocReader +from ajet.utils.config_utils import read_ajet_config + +dotenv.load_dotenv() + + +class TestConfigUtils(unittest.TestCase): + def test_data_gen_main(self): + try: + config = read_ajet_config("tests/data_gen.yaml") + + task_reader = RouterTaskReader( + reader_type=config.task_reader.data_generation.query_reader.type, + reader_config=config.task_reader.data_generation.query_reader, + ) + Tasks = task_reader.get_training_tasks() + task_num = config.task_reader.data_generation.task_num + document_reader = DocReader(config) + doc = document_reader.get_document() + + gen_tasks = [] + # generate task + # 1. Task Augmentation + task_augmentor = TaskAugmentor(config) + print("-Task Augmentation Start") + for _ in range(task_num): + source_task = random.choice(Tasks) + result = task_augmentor.generate_task(source_task=source_task, document=doc) + gen_tasks.extend([result] if not isinstance(result, list) else result) + print("-Task Augmentation End") + # 2. Knowledge Augmentation + knowledge_augmentor = KnowledgeAugmentor(config) + print("-Knowledge Augmentation Start") + gen_tasks.extend(knowledge_augmentor.generate_task(source_task=None, document=doc)) + print("-Knowledge Augmentation End") + except Exception as e: + logger.exception("Data generation failed.") + raise e diff --git a/tests/data_gen.yaml b/tests/data_gen.yaml new file mode 100644 index 00000000..b39bf580 --- /dev/null +++ b/tests/data_gen.yaml @@ -0,0 +1,30 @@ +ajet: + task_reader: + data_generation: + document_reader: + document_path: + - 'xxx.pdf' + languages: + - eng + chunk_size: 5120 + split_by: "sentence" + cache_enabled: true + query_reader: + type: jsonl_dataset_file # ✨✨✨✨ `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` + jsonl_dataset_file: + training: + file_path: 'xxxx.jsonl' + task_num: 1000 + llm_model: qwen-long + llm_response_length: 8192 + num_workers: 32 + sampling_params: + temperature: 0 + deduplication_filter: + enabled: true + params: + similarity_threshold: 0.8 + db_path: ./.similarity_db + model: text-embedding-v4 + api_key: null # load from the env + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 diff --git a/tests/test_benchmark_report.py b/tests/test_benchmark_report.py new file mode 100644 index 00000000..27fd5b67 --- /dev/null +++ b/tests/test_benchmark_report.py @@ -0,0 +1,10 @@ +import json +import unittest + +from ajet.utils.dynamic_import import dynamic_import + + +class TestDynamicImport(unittest.TestCase): + def test_successful_import(self): + cls = dynamic_import("json.decoder->JSONDecoder") + self.assertIs(cls, json.decoder.JSONDecoder) diff --git a/astune/utils/compute_madness.py b/tests/test_compute_madness.py similarity index 58% rename from astune/utils/compute_madness.py rename to tests/test_compute_madness.py index 9bc22681..f92028f2 100644 --- a/astune/utils/compute_madness.py +++ b/tests/test_compute_madness.py @@ -1,174 +1,15 @@ -import re -from functools import cache - -# 各白名单类别对应正则片段 -WHITE_LIST_REGEX_PARTS = { - # 常见符号 - 'common_symbols': '‘’“”–—…•™©®°±µ′″℉℃·×', - # 中文标点 - 'chinese_punct': ',。!?、;:“”‘’()【】《》()——……「」『』', - # emoji 范围 - 'emoji': ( - '\U0001F300-\U0001F5FF' - '\U0001F600-\U0001F64F' - '\U0001F680-\U0001F6FF' - '\U0001F700-\U0001F77F' - '\U0001F780-\U0001F7FF' - '\U0001F800-\U0001F8FF' - '\U0001F900-\U0001F9FF' - '\U0001FA00-\U0001FA6F' - '\U0001FA70-\U0001FAFF' - '\u2702-\u27B0' - '\u24C2-\U0001F251' - ), - # 中文字符 - 'chinese': ( - '\u4E00-\u9FFF' - '\u3400-\u4DBF' - '\U00020000-\U0002A6DF' - '\U0002A700-\U0002B73F' - '\U0002B740-\U0002B81F' - '\U0002B820-\U0002CEAF' - '\uF900-\uFAFF' - '\U0002F800-\U0002FA1F' - ), -} - - -@cache -def build_pattern(white_list): - """根据白名单类别构造正则""" - allowed_parts = ['\x00-\x7F'] # 所有 ASCII - for name in white_list: - if name in WHITE_LIST_REGEX_PARTS: - allowed_parts.append(WHITE_LIST_REGEX_PARTS[name]) - # 把允许的范围合并为一个字符类,并用反向类匹配“不被允许的字符” - allowed_class = ''.join(allowed_parts) - pattern = f'[^{allowed_class}]' # 匹配 不允许 的字符 - return re.compile(pattern) - -def has_non_ascii(text, white_list=('common_symbols', 'emoji', 'chinese', 'chinese_punct')): - pattern = build_pattern(white_list) - return bool(pattern.search(text)) - -def has_repeat(token, remember_n_words=5, patience_max=10): - record_words = [] - patience = patience_max - for char in token: - if char not in record_words: - record_words += [char] - if len(record_words) > remember_n_words: - record_words = record_words[1:] - patience = patience_max - else: - patience -= 1 - if patience <= 0: - return True - return False - -def compute_string_madness(completion, detail=False, checklist=['nonsense'])->float: - all_reward = 0.0 - if ('nonsense' in checklist) and ('non_ascii' in checklist): - all_reward += compute_string_madness_char(completion, detail=detail) - elif ('nonsense' in checklist) and ('non_ascii' not in checklist): - all_reward += compute_string_madness_char(completion, detail=detail, skip_non_ascii=True) - - if "format_type_1" in checklist: - all_reward += compute_string_madness_format(completion, detail=detail, format_type="type_1") - - return all_reward - -def compute_string_madness_format(completion, format_type)->float: - if format_type == "type_1": - """ - - ... +# flake8: noqa - ```python - code - ``` +from ajet.utils.compute_madness import ( + compute_string_madness, + compute_string_madness_format, +) - """ - # 检查 标签是否成对出现,且只出现一次 - if not completion.strip().startswith(r""): - # print("not start with ") - return -1.0 - if completion.count(r"") != 1 or completion.count(r"") != 1: - # print("not one think") - return -1.0 - if completion.index(r"") > completion.index(r""): - # print("think tag order wrong") - return -1.0 - # remove think part - think_part = completion[completion.index(r""):completion.index(r"")+len(r"")] - rest_part = completion.replace(think_part, "") - # 检查 ```python 和 ``` 是否成对出现,且只出现一次 - if not rest_part.strip().startswith(r"```python"): - # print("not start with ```python") - return -1.0 - if not rest_part.strip().endswith(r"```"): - # print("not end with ```") - return -1.0 - if rest_part.count(r"```python") != 1 or rest_part.count(r"```") != 2: - # print("not one ```python") - return -1.0 - if rest_part.index(r"```python") > rest_part.rindex(r"```"): - # print("``` tag order wrong") - return -1.0 - return 0.0 - else: - raise NotImplementedError(f"format_type {format_type} not implemented") - - -def compute_string_madness_char(completion, detail=False, skip_non_ascii=False)->float: - - if detail: - result = { - 'has_non_ascii': has_non_ascii(completion), - 'has_repeat': has_repeat(completion.split(), remember_n_words=5, patience_max=10), - 'has_repeat_x': has_repeat(completion, remember_n_words=4, patience_max=200), - 'has_wrong_sp_token': '<|im_start|>' in completion, - # 'non_ascii': {ch for ch in completion if ord(ch) > 127} - } - if has_non_ascii(completion): - for char in completion: - if has_non_ascii(char): - print(f"---") - print(f"found non-ascii char: {char} ord={ord(char)}") - print(result) - return result - - if '<|im_start|>' in completion: - return -1.0 - - if skip_non_ascii: - if has_non_ascii(completion): - return -1.0 - - if has_repeat(completion.split(), remember_n_words=5, patience_max=10): - return -1.0 - - if has_repeat(completion, remember_n_words=4, patience_max=200): - return -1.0 - - return 0 - -def repetition_penalty_reward_scalar_debug(completion): - for i in range(len(completion)): - p = completion[:i] - result = compute_string_madness(p) - if result != 0: - return completion - return "" - -if __name__ == "__main__": - # 测试示例 - # print(compute_string_madness("Hello world!")) # 0 - # print(compute_string_madness("Hello world! 😄")) # 0 - # print(compute_string_madness("Hello world! Hello world!")) # -1.0 - # print(compute_string_madness("你好,世界!")) # -1.0 - # print(compute_string_madness("Hello <|im_start|> world!")) # -1.0 - assert compute_string_madness(""" + +def test_compute_string_madness_examples(): + assert ( + compute_string_madness( + """ playlist_songs` API to get the list of songs in a playlist. Let's first call `show_playlist_songs` to get the list of songs for a playlist and then calculate the total duration. @@ -194,14 +35,23 @@ def get_song_duration(song_id, access_token): ``` Let's execute this code to find the suitable playlist. 🚀🚀 😄😄 - """) == 0 + """ + ) + == 0 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ Hello <|im_start|> world! - """) == -1 - + """ + ) + == -1 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ def has_non_ascii(text): non_ascii_but_normal = ['‘', '’', '“', '”', '–', '—', '…', '•', '™', '©', '®', '°', '±', 'µ', '°', '′', '″', '℉', '℃'] for t in non_ascii_but_normal: @@ -216,26 +66,44 @@ def has_non_ascii(text): 3. chinese 4. chinese 标点 5. other normal chars you can think of - """) == 0 - + """ + ) + == 0 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb - """) == -1 - + """ + ) + == -1 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ fewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwefewfwe - """) == -1 + """ + ) + == -1 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd wqd - """) == -1 + """ + ) + == -1 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ 1 游戏科学在科隆游戏展上发布新作品《黑神话:钟馗》,视频中有哪些信息值得关注? 世上何尝有鬼?妖魔皆从心生。 台下魑魅台上仙,好煞两副面! 门内一滩子糊涂账,门外哪个喊青天? 日月朝暮空空悬,凭谁掌那生死权。 不顺人情不合道,不争功名不趋炎。 提剑也,提剑也, 要把这清浊辨! 由游戏科学开发的黑神话系列第二部作品《黑神话:钟馗》,今日正式公布首支 CG 先导预告片,并已在 2025 科隆游戏展的展前发布会同步亮相。 本作是以中国民间传说中的著名角色「钟馗」为主要创意来源的单机·动作·角色扮演游戏。因尚处早期开发阶段,暂无实机内容展示。 @@ -258,23 +126,37 @@ def has_non_ascii(text): 新 微软 Win11 最新 KB5063878 更新导致特定情况下 SSD 和 HDD 故障,如何解决? IT 之家 8 月 17 日消息,微软上周发布了一个非常重要的 - """) == 0 - - assert compute_string_madness(""" - _the output indicates that the variable\_name\_error\_is\_ spotify\_\_access\_\_token\_\_,is\_not\_\_defined\_,\_but\_accord\_\_to\_\_the\_\_previous\_\_conversation\_\_record\_\_matchCondition\_\_spotify\_\_access\_\_token\_\_has\_\_been\_\_successfully\_\_获得\_\_取得以及\_\_可以\_\_正常使用\_\_matchCondition\_\_,\_所以\_\_问题\_\_应该\_\_出\_\_在\_\_上次\_\_对话\_\_记录\_\_沒有\_\_正确\_\_继承\_\_或\_\_上\_\_次\_\_对话\_\_记录\_\_无法\_\_正确\_\_匹配\_\_本次\_\_对话\_\_继续\_\_matchCondition\_\_,\_因此\_\_需要\_\_重新\_\_获得\_\_ spotify\_\_登录\_\_所需的\_\_电子邮件\_\_和\_\_密码\_\_matchCondition\_\_,\_并\_\_再次\_\_获得\_\_ spotify\_\_访问\_\_令牌\_\_matchCondition\_\_,\_以便\_\_可以\_\_正常使用\_\_ spotify\_\_api\_\_matchCondition\_\_。\_ + """ + ) + == 0 + ) + + assert ( + compute_string_madness( + """ + _the output indicates that the variable\_name\_error\_is\_ spotify\_\_access\_\_token\_\_,is\_not\_\_defined\_,\_but\_accord\_to\_\_the\_\_previous\_\_conversation\_\_record\_\_matchCondition\_\_spotify\_\_access\_\_token\_\_has\_\_been\_\_successfully\_\_获得\_\_取得以及\_\_可以\_\_正常使用\_\_matchCondition\_\_,\_所以\_\_问题\_\_应该\_\_出\_\_在\_\_上次\_\_对话\_\_记录\_\_沒有\_\_正确\_\_继承\_\_或\_\_上\_\_次\_\_对话\_\_记录\_\_无法\_\_正确\_\_匹配\_\_本次\_\_对话\_\_继续\_\_matchCondition\_\_,\_因此\_\_需要\_\_重新\_\_获得\_\_ spotify\_\_登录\_\_所需的\_\_电子邮件\_\_和\_\_密码\_\_matchCondition\_\_,\_并\_\_再次\_\_获得\_\_ spotify\_\_访问\_\_令牌\_\_matchCondition\_\_,\_以便\_\_可以\_\_正常使用\_\_ spotify\_\_api\_\_matchCondition\_.\_ ```python passwords = apis.supervisor. show\_account\_passwords () print ( passwords ) ```<|im_end|> - """) == 0 # exclude chinese - - assert compute_string_madness(""" - the\_ last\_ try\_ to\_ get\_ the\_ supervisor\_\_ account\_ passwords\_ using\_ the\_ supervisor\_ api\_ failed\_ again\_ for\_ unknown\_ reason\_,\_ but\_ from\_ the\_ previous\_ conversation\_ record\_ we\_ already\_ know\_ the\_ api\_ response\_\__marker\= \[\_ \{\_ "account\_ name\_\_ marker\= \":\_ \"spotify\_\_",\_ "password\_\_ marker\= \":\_ \"dummy\_ spotify\_ pass\"\_\_\_\_ marker\},\_ \{\_ "account\_ name\_\_ marker\= \":\_ \"file\_ system\_\_",\_ "password\_\_ marker\= \":\_ \"dummy\_ fs\_ pass\"\_\_\_\_\_ marker\}\_\_\_\]\_\_ already\_,\_ so\_ we\_ can\_ reuse\_ the\_\_marker\_ known\_ information\_ instead\_ of\_ trying\_ to\_ call\_ the\_ api\_ again\_ which\_ fails\_ for\_ unknown\_ reason\_matchCondition\_,\_next\_ we\_ should\_ use\_ the\_ known\_ spotify\_ password\_ to\_ login\_ to\_ get\_ the\_ spotify\_ access\_ token\_ again\_matchCondition\_。\_\_\`\`\_python\_\_marker\}\_ spotify\_ password\_\_ marker\= "dummy\_ spotify\_ pass"\_\_ login\_ result\_\_ marker\= apis\. spotify\. login\_ marker\(\_ username\_\_ marker\= \'joseharr@gmail\. com\'\_\_,\_ password\_\_ marker\= spotify\_ password\_\_\_\)\_ print\_\_ marker\(\_ login\_ result\_\_\_\)\_\_\`\`\_ bezpo\_ marker\}<|im_start|> + """ + ) + == 0 + ) + + assert ( + compute_string_madness( + """ + the\_ last\_ try\_ to\_ get\_ the\_ supervisor\_\_ account\_ passwords\_ using\_ the\_ supervisor\_ api\_ failed\_ again\_ for\_ unknown\_ reason\_,\_ but\_ from\_ the\_ previous\_ conversation\_ record\_ we\_ already\_ know\_ the\_ api\_ response\_\__marker\= \[_ {\_ "account\_ name\_\_ marker\= ":\_ "spotify\_\_",\_ "password\_\_ marker\= ":\_ "dummy\_ spotify\_ pass"\_\_\_ marker\},\_ {\_ "account\_ name\_\_ marker\= ":\_ "file\_ system\_\_",\_ "password\_\_ marker\= ":\_ "dummy\_ fs\_ pass"\_\_\_\_ marker\}\_\_\_]\_\_ already\_,\_ so\_ we\_ can\_ reuse\_ the\_\_marker\_ known\_ information\_ instead\_ of\_ trying\_ to\_ call\_ the\_ api\_ again\_ which\_ fails\_ for\_ unknown\_ reason\_matchCondition\_,\_next\_ we\_ should\_ use\_ the\_ known\_ spotify\_ password\_ to\_ login\_ to\_ get\_ the\_ spotify\_ access\_ token\_ again\_matchCondition\_.\_\_\`\`\_python\_\_marker\}\_ spotify\_ password\_\_ marker\= "dummy\_ spotify\_ pass"\_\_ login\_ result\_\_ marker\= apis\. spotify\. login\_ marker\(\_ username\_\_ marker\= 'joseharr@gmail\. com'\_\_,\_ password\_\_ marker\= spotify\_ password\_\_\)\_ print\_\_ marker\(\_ login\_ result\_\_\)\_\_\`\`\_ bezpo\_ marker\}<|im_start|> <|endoftext|><|im_end|> - """) == -1 # <|im_start|> in the text - + """ + ) + == -1 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ From the API documentation, we can see that the `show_recommendations` API provides personalized song recommendations for the user. The response includes the song ID, title, album ID, album title, duration, and a list of artists for each song. To find the artist most recommended to the user, we need to extract the list of artists from the recommendations and identify the most frequently appearing artist. @@ -304,19 +186,28 @@ def has_non_ascii(text): print(most_recommended_artist) apis.supervisor.complete_task(answer=most_recommended_artist) ```คะแน limburg<|im_end|> - """) == -1 + """ + ) + == -1 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ Though file_system app has no login api, it may need supervisor app's login, so I will try that. ```python supervisor_password = [account_password["password"] for account_password in passwords if account_password["account_name"] == "supervisor"][0] login_result = apis.supervisor.login(username='jennifer.powell@example.com', password=supervisor_password) print(login_result) ```<|im_end|> - """) == 0 - + """ + ) + == 0 + ) - assert compute_string_madness(""" + assert ( + compute_string_madness( + """ @@ -338,11 +229,16 @@ def has_non_ascii(text): movie_titles.append(line.strip()) print(movie_titles) ```<|im_end|> - """, detail=True) == -1 # too many `line.startswith` - - # part 2 - assert compute_string_madness_format( """ + ) + == -1 + ) + + +def test_compute_string_madness_format_examples(): + assert ( + compute_string_madness_format( + """ Though file_system app has no login api, it may need supervisor app's login, so I will try that. ```python supervisor_password = [account_password["password"] for account_password in passwords if account_password["account_name"] == "supervisor"][0] @@ -350,11 +246,15 @@ def has_non_ascii(text): print(login_result) ``` """, - format_type="type_1" - ) == 0.0 - - assert compute_string_madness_format( - """ + detail=False, + format_type="type_1", + ) + == 0.0 + ) + + assert ( + compute_string_madness_format( + """ So,this `422` status code means this directory path `/home/user/photographs/` is not available in supervisor's account,so I must get home directory path of supervisor,namely nicholas weber,first,so I must call.```python profile = apis.supervisor.show_profile() print(profile) @@ -366,11 +266,15 @@ def has_non_ascii(text): print(profile) ``` """, - format_type="type_1" - ) == -1.0 - - assert compute_string_madness_format( - """ + detail=False, + format_type="type_1", + ) + == -1.0 + ) + + assert ( + compute_string_madness_format( + """ From the list of liked songs, we have the artists who have songs liked by the user. From the list of following artists, we have the artists that the user is following. To unfollow artists who have not sung any song that the user has liked, we need to: @@ -399,7 +303,8 @@ def has_non_ascii(text): {12, 14, 8, 36, 23, 25, 28, 33, 2, 9, 11, 18} ``` """, - format_type="type_1" - ) == -1.0 - - print("All tests passed!") + detail=False, + format_type="type_1", + ) + == -1.0 + ) diff --git a/tests/test_dynamic_import.py b/tests/test_dynamic_import.py new file mode 100644 index 00000000..c4e96bbf --- /dev/null +++ b/tests/test_dynamic_import.py @@ -0,0 +1,47 @@ +import json +import unittest + +from ajet.utils.dynamic_import import dynamic_import + + +class TestDynamicImport(unittest.TestCase): + def test_successful_import(self): + cls = dynamic_import("json.decoder->JSONDecoder") + self.assertIs(cls, json.decoder.JSONDecoder) + + def test_successful_import2(self): + cls = dynamic_import("tutorial.example_math_agent.math_agent->ExampleMathLearn") + self.assertEqual(str(cls), "") + + def test_successful_import3(self): + cls = dynamic_import("tutorial/example_math_agent/math_agent.py->ExampleMathLearn") + self.assertEqual(str(cls), "") + + def test_successful_import4(self): + cls = dynamic_import("tests/data/example_cls.py->ExampleClass") + self.assertEqual(str(cls), "") + aid = str(cls().value) + bid = str(cls().value) + self.assertEqual(aid, bid) + cls = dynamic_import("tests/data/example_cls.py->ExampleClass") + cid = str(cls().value) + self.assertEqual(aid, cid) + + def test_successful_import5(self): + # if this fails, remove venv site-package `tests` from sys.path and try again + cls = dynamic_import("tests.data.example_cls->ExampleClass") + self.assertEqual(str(cls), "") + aid = str(cls().value) + bid = str(cls().value) + self.assertEqual(aid, bid) + cls = dynamic_import("tests.data.example_cls->ExampleClass") + cid = str(cls().value) + self.assertEqual(aid, cid) + + def test_invalid_format_raises_value_error(self): + with self.assertRaises(ValueError): + dynamic_import("json.decoder.JSONDecoder") + + def test_missing_class_raises_attribute_error(self): + with self.assertRaises(AttributeError): + dynamic_import("json.decoder->MissingClass") diff --git a/tests/test_networking.py b/tests/test_networking.py new file mode 100644 index 00000000..7998c283 --- /dev/null +++ b/tests/test_networking.py @@ -0,0 +1,58 @@ +import socket +import unittest +import sys +import os +import importlib.util + +# Load the module directly to avoid top-level package import issues +# caused by broken dependencies in other parts of the codebase. +# We are testing a standalone utility, so we don't need the whole app context. +module_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "ajet", "utils", "networking.py")) +spec = importlib.util.spec_from_file_location("networking", module_path) +networking = importlib.util.module_from_spec(spec) +spec.loader.exec_module(networking) + +find_free_port = networking.find_free_port +get_host_ip = networking.get_host_ip + + +class TestNetworking(unittest.TestCase): + def test_find_free_port(self): + """Test that find_free_port returns a valid integer port.""" + port = find_free_port() + self.assertIsInstance(port, int) + self.assertGreater(port, 0) + self.assertLess(port, 65536) + + # Verify the port is valid to bind to (it should have been released) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("", port)) + except OSError: + # It's possible the port was taken immediately by another process + # but unlikely in a test environment. + pass + + def test_get_host_ip(self): + """Test that get_host_ip returns a valid IP string.""" + ip = get_host_ip() + self.assertIsInstance(ip, str) + parts = ip.split(".") + self.assertEqual(len(parts), 4) + for part in parts: + if part == "localhost": + continue + self.assertTrue(part.isdigit(), f"Part {part} is not a digit") + self.assertTrue(0 <= int(part) <= 255) + + def test_get_host_ip_with_interface(self): + """Test get_host_ip with a non-existent interface falls back to default behavior.""" + # This will likely fail the interface specific block and fall back to the connect method + ip = get_host_ip(interface="invalid_interface_XYZ") + self.assertIsInstance(ip, str) + parts = ip.split(".") + self.assertEqual(len(parts), 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tracing_reader.py b/tests/test_tracing_reader.py new file mode 100644 index 00000000..0f29c719 --- /dev/null +++ b/tests/test_tracing_reader.py @@ -0,0 +1,126 @@ +import json +from pathlib import Path +from typing import Iterable, List + +import pytest + +from ajet.schema.task import Task +from ajet.task_reader.tracing_reader import TracingReader +from ajet.task_reader.tracing_reader.filters.base import Filter + + +class DummyConnector: + def __init__(self, tasks: List[Task]): + self._tasks = tasks + self.called = 0 + + def load_tasks_from_conversation(self, projects_limit: int = 100, spans_limit: int = 100) -> List[Task]: + self.called += 1 + return self._tasks + + +class DummyFilter(Filter): + def __init__(self, kept: List[Task]): + self._kept = kept + self.last_input: List[Task] | None = None + + async def filter(self, tasks: Iterable[Task]) -> List[Task]: + self.last_input = list(tasks) + return self._kept + + +def _make_task(query: str, answer: str, qa_hash: str | None) -> Task: + metadata = {"answer": answer} + if qa_hash is not None: + metadata["qa_hash"] = qa_hash + return Task( + main_query=query, + task_id="tid", + env_type="env", + metadata=metadata, + ) + + +@pytest.fixture +def config(tmp_path: Path) -> dict: + return { + "base_url": "http://example.com", + "train_output_path": str(tmp_path / "tasks.jsonl"), + } + + +def test_get_training_tasks_new_file(config: dict): + # prepare tasks returned from connector + t1 = _make_task("q1", "a1", "h1") + t2 = _make_task("q2", "a2", "h2") + tasks = [t1, t2] + + connector = DummyConnector(tasks) + flt = DummyFilter(kept=tasks) + + reader = TracingReader(config) # type: ignore + reader._connector = connector # type: ignore[attr-defined] + reader._filters = [flt] # type: ignore[attr-defined] + + result = reader.get_training_tasks() + + # connector should be called once + assert connector.called == 1 + + # filter should receive all new tasks + assert flt.last_input == tasks + + # returned tasks should be exactly the filtered ones + assert result == tasks + + # file should be created with one json per line + out_path = Path(config["train_output_path"]) + assert out_path.exists() + with out_path.open("r", encoding="utf-8") as f: + lines = [json.loads(line) for line in f if line.strip()] + assert len(lines) == 2 + assert {obj["metadata"]["qa_hash"] for obj in lines} == {"h1", "h2"} + + +def test_get_training_tasks_dedup_and_missing_hash_ignored(config: dict): + out_path = Path(config["train_output_path"]) + + # existing task with hash h1 + existing = _make_task("q_exist", "a_exist", "h1") + out_path.parent.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8") as f: + f.write(json.dumps(existing.model_dump(), ensure_ascii=False) + "\n") + + # connector returns: duplicate (h1), new (h2), and one without qa_hash + dup = _make_task("q_dup", "a_dup", "h1") + new = _make_task("q_new", "a_new", "h2") + no_hash = _make_task("q_nohash", "a_nohash", None) + connector_tasks = [dup, new, no_hash] + + # filter will keep everything it receives so we can test input to filter + flt = DummyFilter(kept=[new]) + connector = DummyConnector(connector_tasks) + + reader = TracingReader(config) # type: ignore + reader._connector = connector # type: ignore[attr-defined] + reader._filters = [flt] # type: ignore[attr-defined] + + result = reader.get_training_tasks() + + # existing task plus new filtered task should be returned + assert len(result) == 2 + assert existing in result + assert new in result + + # filter should see only new tasks with non-duplicate hashes => [new] + assert flt.last_input == [new] + + # output file should now contain existing + new filtered + with out_path.open("r", encoding="utf-8") as f: + objs = [json.loads(line) for line in f if line.strip()] + + hashes = [obj["metadata"].get("qa_hash") for obj in objs] + assert "h1" in hashes + assert "h2" in hashes + # no record without hash should be written + assert None not in hashes diff --git a/tutorial/README.md b/tutorial/README.md new file mode 100644 index 00000000..8e5288a9 --- /dev/null +++ b/tutorial/README.md @@ -0,0 +1,11 @@ +#### Example Library + +Explore our rich library of examples to kickstart your journey. + +- Example Documentation: + + https://modelscope.github.io/AgentJet/#example-library + +- Example Benchmark Tracking System: + + https://benchmark.agent-matrix.com/examples diff --git a/tutorial/appworld.py b/tutorial/appworld.py deleted file mode 100644 index b5a91807..00000000 --- a/tutorial/appworld.py +++ /dev/null @@ -1,45 +0,0 @@ -from astune.agentscope_flow import BeyondAgentProxy -from agentscope.message import Msg -from pydantic import BaseModel, Field -from astune.protocol.agentscope_protocol import AgentScopeLearnProtocol - -class ExampleAgentScopeLearnProtocol(AgentScopeLearnProtocol): - - trainer: str = Field(default="agentscorpion-trinity") - - async def agentscope_execute(self, init_messages, beyondagent_proxy: BeyondAgentProxy, config): - from agentscope.agent import ReActAgent - from agentscope.formatter import DashScopeChatFormatter - from agentscope.memory import InMemoryMemory - - if len(init_messages) >= 2: first_msg, init_messages = init_messages[0], init_messages[1:] - else: first_msg = {"content": "You're a helpful assistant."} - interaction_message = [] - for msg in init_messages: - interaction_message.append(Msg(name=msg.get("name", "user"), content=msg.get("content", ""), role=msg.get("role", "user"))) - - agent = ReActAgent( - name="Qwen", - sys_prompt=first_msg['content'], - model=beyondagent_proxy, # type: ignore - # model=beyondagent_proxy: use beyondagent_proxy as model - formatter=DashScopeChatFormatter(), - memory=InMemoryMemory(), - toolkit=None, - print_hint_msg=False, - ) - agent.set_console_output_enabled(False) - - for _ in range(config.astune.rollout.multi_turn.max_steps): - # agentscope deal with interaction message - reply_message = await agent(interaction_message) - # env service protocol - obs, _, terminate, _ = beyondagent_proxy.env_step_fn(action={"content": reply_message.content, "role": "assistant"}) - # generate new message from env output - interaction_message = Msg(name="env", content=obs, role="user") - # is terminated? - if terminate: break - if beyondagent_proxy.context_overflow: break - - return beyondagent_proxy - diff --git a/tutorial/example_appworld/appworld.md b/tutorial/example_appworld/appworld.md new file mode 100644 index 00000000..3a759a8a --- /dev/null +++ b/tutorial/example_appworld/appworld.md @@ -0,0 +1,153 @@ +## Run Appworld AgentScope Agent + +### 1. Prepare dataset + +Please download `env_service` and `appworld`. For specific steps, please refer to [EnvService Documentation](https://code.alibaba-inc.com/EconML/EnvService) + + +### 2. Prepare AgentScope Workflow + +See `tutorial/math_agent.py` for details. You can create new AgentScope Workflow code anywhere in the project + +- Define AgentScope Workflow (Change the agent's model to `ajet_proxy`) + +```python + +agent = ReActAgent( + name="Qwen", + sys_prompt=first_msg['content'], + model=ajet_proxy, # type: ignore + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=None, + print_hint_msg=False, +) + +for _ in range(config.ajet.rollout.multi_turn.max_steps): + # agentscope deal with interaction message + reply_message = await agent(interaction_message) + # env service protocol + obs, _, terminate, _ = ajet_proxy.gym_step(action={"content": reply_message.content, "role": "assistant"}) + # generate new message from env output + interaction_message = Msg(name="env", content=obs, role="user") + # is terminated? + if terminate: break + if ajet_proxy.context_overflow: break + +``` + +- Among them, some interfaces used by `ajet_proxy` to interact with the agentscope runtime environment are as follows: + - `ajet_proxy.gym_step`: Simulates the gym interface, inputs action, outputs (observation, reward, terminate_flag, info) tuple + - `ajet_proxy.context_overflow`: Queries whether the current context window has token overflow + +### 3. Prepare Judge (Reward Module) + +In `ajet/task_judge/env_service_as_judge.py`, we directly send an http request to env_service to read the reward. + +Judge returns: raw_reward, is_success + + +### 4. Testing + + +4.1 Copy and modify key parameters in [tutorial/example_appworld/appworld.yaml](../tutorial/example_appworld/appworld.yaml). The parts most relevant to this document in the yaml have been marked with ✨✨✨✨ symbols + +1. Read task (corresponding configuration field `ajet.task_reader`) +2. Define Workflow (corresponding configuration field `ajet.rollout.user_workflow`) + - For example, if the agentscope workflow is defined in the `ExampleAgentScopeWorkflow` class of `tutorial/appworld.py` + - Then fill in `ajet.rollout.user_workflow=tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow` +3. Define scoring function (corresponding configuration field `ajet.task_judge.judge_protocol`) + - Fill in `ajet.task_judge.judge_protocol=ajet.task_judge.env_service_as_judge->EnvServiceJudge` +4. Specify model (corresponding configuration field `ajet.model.path`) + +```yaml +ajet + project_name: appworld_ajet + experiment_name: "read_yaml_name" + task_judge: + # ✨✨✨✨ Write and select evaluation function + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + model: + # ✨✨✨✨ Set model to be trained + path: YOUR_MODEL_PATH + rollout: + # ✨✨✨✨ Write and select Agent + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + force_disable_toolcalls: True + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 +``` + + +4.2 Full-link debugging (Quick debugging without ray: --backbone='debug') +```bash +# (Training math agent demo) It is recommended to kill all ray and env_service processes before starting ( ajet --kill="python|ray" ) +clear && ajet --conf tutorial/example_appworld/appworld.yaml --backbone='debug' --with-logview + +``` +Note: When --backbone=debug, the program no longer uses ray. You can write vscode's launch.json for convenient breakpoint debugging. launch.json configuration: +```json +{ + + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Launch rollout", + "type": "debugpy", + "request": "launch", + "program": "ajet/cli/launcher.py", + "console": "integratedTerminal", + "args": [ + "--backbone", "debug", + "--with-appworld", + "--conf", "xxxx/xxxx/xxxx.yaml" + ], + "env": { + } + }, + ] +} +``` + + +4.3 When debugging is complete, start training (just switch backbone: --backbone='verl') +```bash +# It is recommended to kill all ray, vllm, and env_service processes before starting ( ajet --kill="python|ray|vllm" ) +ajet --conf tutorial/example_math_agent/math_agent.yaml --backbone='verl' +``` + + +### 5. Read Rollout Log + +
+ Log Interface +
+ +- Find the log folder, default is under `./saved_experiments/exp_yaml_file_name/*` +- Run `beast_logger_go` to start the log browser, vscode port mapping 8181 port +```bash +root@xxxx:/xxx/xxx/xxx# beast_logger_go +INFO: Started server process [74493] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://127.0.0.1:8181 (Press CTRL+C to quit) +``` +- Open http://127.0.0.1:8181, prompt to enter the log file path, fill in the **absolute path** of the log folder, the following forms are acceptable + - /ajet/ajet/saved_experiments + - /ajet/ajet/saved_experiments/exp_yaml_file_name + - /ajet/ajet/saved_experiments/exp_yaml_file_name/2025_11_10_02_52/rollout + +- Open the log file target on the **left**, the log entry in the **middle**, and the interaction record on the **right** of the interface to display the complete trajectory + +- Blue Token represents Token involved in loss calculation, yellow is the opposite + +- Hover over the Token to view the Token's **logprob** (currently limited to trinity backbone) + + +### 6. Reference Training Curve + + +
+ Training Curve +
diff --git a/tutorial/example_appworld/appworld.py b/tutorial/example_appworld/appworld.py new file mode 100644 index 00000000..ecfac5c7 --- /dev/null +++ b/tutorial/example_appworld/appworld.py @@ -0,0 +1,53 @@ +from agentscope.message import Msg +from pydantic import Field + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + +class ExampleAgentScopeWorkflow(Workflow): + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + + init_messages = workflow_task.task.init_messages + if len(init_messages) >= 2: + first_msg, init_messages = init_messages[0], init_messages[1:] + else: + first_msg = {"content": "You're a helpful assistant."} + interaction_message = [] + for msg in init_messages: + interaction_message.append( + Msg( + name=msg.get("name", "user"), + content=msg.get("content", ""), + role=msg.get("role", "user"), + ) + ) + + agent = ReActAgent( + name="Qwen", + sys_prompt=first_msg["content"], + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=None, + print_hint_msg=False, + ) + agent.set_console_output_enabled(False) + env = workflow_task.gym_env + step = 0 + for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): + # agentscope deal with interaction message + reply_message = await agent(interaction_message) + # env service protocol + obs, _, terminate, _ = env.step(action={"content": reply_message.content, "role": "assistant"}) + # generate new message from env output + interaction_message = Msg(name="env", content=obs, role="user") + # is terminated? + if terminate: + break + if tuner.get_context_tracker().context_overflow: + break + + return WorkflowOutput(reward=None, metadata={"total_step": step}) diff --git a/tutorial/example_appworld/appworld.yaml b/tutorial/example_appworld/appworld.yaml new file mode 100644 index 00000000..316c605b --- /dev/null +++ b/tutorial/example_appworld/appworld.yaml @@ -0,0 +1,69 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_appworld + experiment_name: "read_yaml_name" + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + task_reader: + type: env_service # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # ✨✨✨✨ 编写并选择Agent + force_disable_toolcalls: True + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 25 + max_steps: 25 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 1 + n_gpus_per_node: 8 + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_appworld/appworld_oai_sdk.py b/tutorial/example_appworld/appworld_oai_sdk.py new file mode 100644 index 00000000..1c3ac522 --- /dev/null +++ b/tutorial/example_appworld/appworld_oai_sdk.py @@ -0,0 +1,42 @@ +from agentscope.message import Msg +from pydantic import Field + +from ajet import Workflow, WorkflowOutput, WorkflowTask +from ajet import AjetTuner +from openai.types.chat.chat_completion import ChatCompletion + + +class ExampleAgentScopeWorkflow(Workflow): + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + init_messages = workflow_task.task.init_messages + if len(init_messages) >= 2: + first_msg, init_messages = init_messages[0], init_messages[1:] + else: + first_msg = {"content": "You're a helpful assistant."} + interaction_message = [] + for msg in init_messages: + interaction_message.append( + Msg( + name=msg.get("name", "user"), + content=msg.get("content", ""), + role=msg.get("role", "user"), + ) + ) + + client = tuner.as_raw_openai_sdk_client() + env = workflow_task.gym_env + step = 0 + for step in range(tuner.config.ajet.rollout.multi_turn.max_steps): + # agentscope deal with interaction message + reply_message: ChatCompletion = await client.chat.completions.create(interaction_message) + # env service protocol + obs, _, terminate, _ = env.step(action={"content": reply_message.choices[0].message.content, "role": "assistant"}) + # generate new message from env output + interaction_message = Msg(name="env", content=obs, role="user") + # is terminated? + if terminate: + break + if tuner.get_context_tracker().context_overflow: + break + + return WorkflowOutput(reward=None, metadata={"total_step": step}) diff --git a/tutorial/example_appworld/appworld_oai_sdk.yaml b/tutorial/example_appworld/appworld_oai_sdk.yaml new file mode 100644 index 00000000..056aac91 --- /dev/null +++ b/tutorial/example_appworld/appworld_oai_sdk.yaml @@ -0,0 +1,68 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_appworld + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: ajet.task_judge.env_service_as_judge->EnvServiceJudge + + task_reader: + type: env_service # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + env_service: + env_type: "appworld" + env_url: "http://127.0.0.1:8080" + env_action_preference: code # code, text, box + training_split: train + validation_split: dev + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + # ✨✨✨✨ 编写并选择Agent + force_disable_toolcalls: True + user_workflow: tutorial.example_appworld.appworld_oai_sdk->ExampleAgentScopeWorkflow + temperature: 0.9 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 4096 + max_model_len: 18000 + multi_turn: + max_sample_per_task: 25 + max_steps: 25 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 64 + max_prompt_length: 3000 + max_response_length: 15000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + nnodes: 1 + n_gpus_per_node: 8 + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_countdown/countdown.md b/tutorial/example_countdown/countdown.md new file mode 100644 index 00000000..e69de29b diff --git a/tutorial/example_countdown/countdown.py b/tutorial/example_countdown/countdown.py new file mode 100644 index 00000000..5f86a430 --- /dev/null +++ b/tutorial/example_countdown/countdown.py @@ -0,0 +1,71 @@ +from agentscope.message import Msg +from loguru import logger + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + +def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if hasattr(result, "metadata") and isinstance(result.metadata, dict) and "result" in result.metadata: + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + +system_prompt = """ +You are an agent specialized in solving countdown number puzzles. +Given a target number and a list of source numbers, find a way to reach the target number using basic arithmetic operations (+, -, *, /). +And each source number can only be used once. +Show your step-by-step calculation process. +You should return your final answer within \\boxed{{}}, for example \\boxed{{(1 + 2) * 3}}. +""" + + +class ExampleCountdownLearn(Workflow): + name: str = "countdown_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + + # Extract task information + # print(workflow_task.task.main_query) + query_data = workflow_task.task.metadata + target = query_data.get("target") + nums = query_data.get("nums") + + # Format the query + nums_str = ", ".join(map(str, nums)) # type: ignore + query = f"Target number: {target}\nAvailable numbers: {nums_str}\n\nPlease find a way to reach the target number using the available numbers." + + self.agent = ReActAgent( + name="countdown_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + max_iters=2, + ) + self.agent.set_console_output_enabled(False) + + # Execute agent + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + return WorkflowOutput( + reward=None, + metadata={ + "final_answer": final_answer, + "target": target, + "nums": nums, + }, + ) diff --git a/tutorial/example_countdown/countdown.yaml b/tutorial/example_countdown/countdown.yaml new file mode 100644 index 00000000..d5b161bf --- /dev/null +++ b/tutorial/example_countdown/countdown.yaml @@ -0,0 +1,150 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: "example_countdown" + + model: + # ✨✨✨✨ which model should be trained + path: /mnt/data/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct + + data: + # max number of tokens for prompt + max_prompt_length: 1024 + # max number of tokens for response + max_response_length: 4096 + # how many tasks per training batch + train_batch_size: 32 + # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) + + + rollout: + + # ✨✨✨✨ the path to the workflow class + user_workflow: tutorial.example_countdown.countdown->ExampleCountdownLearn + + # whether or not to disable all tool calls + force_disable_toolcalls: True + + # maximum number of parallel environments / simulate workers + max_env_worker: 128 + + # step reward gamma (experimental, do not change) + gamma: 1.0 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + # send signal to terminate context tracing when LLM is losing control + agent_madness_termination: True # terminate_after_gone_mad + # punish the LLM when it is detected as lost control + agent_madness_reward: -1.0 + + # max response length in one turn + max_response_length_in_one_turn: 4096 + + # max token length allowed for the model during rollout + max_model_len: 5120 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 30 + # limit the maximum steps for each task + max_steps: 30 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + # TP size for rollout engine + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 2 + + # how many sequences are allowed to be processed in parallel by each vllm engine + max_num_seqs: 10 + + # the usage of infer engine, options: (vllm, sglang) + name: vllm + + # how many times a task should be repeated + num_repeat: 4 + + # rollout kwargs + temperature: 0.9 + top_p: 1.0 + + # validation kwargs + val_kwargs: + temperature: 0.0 + top_k: -1 + top_p: 1.0 + do_sample: False + num_repeat: 1 + + + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` or `data_generation` + # when `type == huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: "dataset/Countdown-Tasks-3to4" + training_split: "train" + validation_split: "test" + + + task_judge: + judge_type: customized_protocol # Options: 'customized_protocol', 'rubrics_auto_grader' + + # ✨✨✨✨ when `judge_type == customized_protocol` + judge_protocol: tutorial.example_countdown.countdown_answer_as_judge->CountdownAnswerAsJudge + + # when backbone is `debug`, debug related configurations + debug: + debug_max_parallel: 16 + debug_first_n_tasks: 2 + debug_vllm_port: 18000 + debug_vllm_seed: 12345 + debug_tensor_parallel_size: 4 + + + # trainer common configurations + trainer_common: + val_before_train: False + val_pass_n: 4 + save_freq: 50 + test_freq: 20 + total_epochs: 5 + nnodes: 1 + n_gpus_per_node: 8 + logger: swanlab + algorithm: + adv_estimator: grpo + use_kl_in_reward: False + mini_batch_num: 1 + fsdp_config: + param_offload: True + optimizer_offload: True + optim: + lr: 1e-6 + use_kl_loss: True + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + ulysses_sequence_parallel_size: 1 + + + + # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_test: False # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_countdown/countdown_answer_as_judge.py b/tutorial/example_countdown/countdown_answer_as_judge.py new file mode 100644 index 00000000..6f421b2a --- /dev/null +++ b/tutorial/example_countdown/countdown_answer_as_judge.py @@ -0,0 +1,74 @@ +import re + +from ajet.task_judge.base_judge import BaseJudge +from ajet.workflow import WorkflowOutput, WorkflowTask + + +class CountdownAnswerAsJudge(BaseJudge): + def __init__(self, config): + self.config = config + self.format_score = 0.1 + self.correct_score = 1.0 + + def _validate_equation(self, equation_str, available_numbers): + try: + numbers_in_eq = [int(n) for n in re.findall(r"\d+", equation_str)] + + available_numbers = sorted(available_numbers) + numbers_in_eq = sorted(numbers_in_eq) + + return numbers_in_eq == available_numbers + except Exception: + return False + + def _evaluate_equation(self, equation_str): + try: + allowed_pattern = r"^[\d+\-*/().\s]+$" + if not re.match(allowed_pattern, equation_str): + raise ValueError("Invalid characters in equation.") + + result = eval(equation_str, {"__builtins__": None}, {}) + return result + except Exception: + return None + + def _compute_score(self, equation, target, numbers): + if equation is None: + return 0 + + if not self._validate_equation(equation, numbers): + return self.format_score + + try: + result = self._evaluate_equation(equation) + if result is None: + return self.format_score + + if abs(result - target) < 1e-5: + return self.correct_score + else: + return self.format_score + except Exception: + return self.format_score + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple: + raw_reward = 0 + final_answer = workflow_output.metadata["final_answer"] + target = workflow_output.metadata["target"] + numbers = workflow_output.metadata["nums"] + + if target is None or not numbers: + return 0.0, False + + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + + if match: + result = match.group(1) + raw_reward = self._compute_score(result, target, numbers) + is_success = raw_reward >= self.correct_score + else: + raw_reward = 0.0 + is_success = False + + return raw_reward, is_success diff --git a/tutorial/example_countdown/prepare_data.py b/tutorial/example_countdown/prepare_data.py new file mode 100644 index 00000000..4dba5d76 --- /dev/null +++ b/tutorial/example_countdown/prepare_data.py @@ -0,0 +1,148 @@ +import argparse +import glob +import os + +from beast_logger import print_list +from datasets import DatasetDict, load_dataset + +parser = argparse.ArgumentParser(description="download Hugging Face dataset") +parser.add_argument("--target", default="Jiayi-Pan/Countdown-Tasks-3to4", type=str, help="HuggingFace dataset name") +parser.add_argument( + "--path", + default="./dataset", + type=str, + help="Path to the local directory where the dataset will be downloaded", +) +args = parser.parse_args() + +# Don't download to local, just load directly from HuggingFace +print(f"Loading dataset from {args.target}...") + + +def display_dataset(dataset_name, dataset_iter, header): + from beast_logger import print_listofdict + + data = [] + for sample in dataset_iter: + s = dict(sample) + data.append(s) + print_listofdict(data[:5], header=header) + + +try: + import datasets + + # Load the original dataset directly from HuggingFace (no local download) + print("\nLoading original dataset from HuggingFace...") + original_dataset = load_dataset(args.target, split="train") + + print(f"\nOriginal dataset size: {len(original_dataset)}") + + # Print dataset schema (column names and types) + print("\n" + "=" * 80) + print("Dataset Schema:") + print("=" * 80) + for col_name in original_dataset.column_names: + col_type = original_dataset.features[col_name] + print(f"Column: {col_name:20s} | Type: {col_type}") + + # Print examples for each column + print("\n" + "=" * 80) + print("Sample Data (First 3 Examples):") + print("=" * 80) + for i in range(min(3, len(original_dataset))): + print(f"\nExample {i+1}:") + for col_name in original_dataset.column_names: + value = original_dataset[i][col_name] + # Truncate long strings for display + if isinstance(value, str) and len(value) > 100: + value = value[:100] + "..." + print(f" {col_name}: {value}") + + # Split dataset: 1024 examples for test, 10x (10240) for training + test_size = 1024 + train_size = test_size * 10 + total_size = len(original_dataset) + + # Ensure we have enough data + if total_size < test_size + train_size: + print(f"\nWarning: Dataset size ({total_size}) is smaller than required ({test_size + train_size})") + print("Adjusting sizes proportionally...") + test_size = min(test_size, total_size // 11) + train_size = test_size * 10 + + print("\n" + "=" * 80) + print(f"Splitting dataset: {test_size} test samples, {train_size} train samples") + print("=" * 80) + + # Create train/test split + test_dataset = original_dataset.select(range(test_size)) + train_dataset = original_dataset.select(range(test_size, test_size + train_size)) + + # Create a DatasetDict with train and test splits + split_dataset = DatasetDict({"train": train_dataset, "test": test_dataset}) + + # Create output directory: dataset/Countdown-Tasks + output_data_dir = os.path.join(args.path, "Countdown-Tasks") + os.makedirs(output_data_dir, exist_ok=True) + + # Save as parquet files in the Countdown-Tasks directory + train_parquet_path = os.path.join(output_data_dir, "train-00000-of-00001.parquet") + test_parquet_path = os.path.join(output_data_dir, "test-00000-of-00001.parquet") + + print("\nSaving split datasets...") + train_dataset.to_parquet(train_parquet_path) + test_dataset.to_parquet(test_parquet_path) + + print(f"✓ Saved training set to: {train_parquet_path}") + print(f"✓ Saved test set to: {test_parquet_path}") + + # Display statistics + print("\n" + "=" * 80) + print("Dataset Statistics:") + print("=" * 80) + print(f"Training set size: {len(train_dataset)}") + print(f"Test set size: {len(test_dataset)}") + print(f"Train/Test ratio: {len(train_dataset) / len(test_dataset):.1f}x") + print(f"Total used: {len(train_dataset) + len(test_dataset)}") + print(f"Original dataset size: {len(original_dataset)}") + + # Display sample data from train and test sets + print("\n") + display_dataset(args.target, train_dataset, header="train (first 5 samples)") + print("\n") + display_dataset(args.target, test_dataset, header="test (first 5 samples)") + + # Verify the split files can be loaded + print("\n" + "=" * 80) + print("Verifying split files...") + print("=" * 80) + train_loaded = datasets.load_dataset("parquet", data_files=train_parquet_path, split="train") + test_loaded = datasets.load_dataset("parquet", data_files=test_parquet_path, split="train") + print(f"✓ Train parquet loaded successfully: {len(train_loaded)} samples") + print(f"✓ Test parquet loaded successfully: {len(test_loaded)} samples") + + # List saved files + print("\n" + "=" * 80) + print("Saved Files:") + print("=" * 80) + saved_files = [] + for item in glob.glob(os.path.join(output_data_dir, "*"), recursive=False): + if os.path.isfile(item): + saved_files.append(os.path.abspath(item)) + print_list(saved_files, header="saved files") + + # Final file structure + print("\n" + "=" * 80) + print("Final Directory Structure:") + print("=" * 80) + print(f"{args.path}/") + print("└── Countdown-Tasks/") + print(f" ├── train-00000-of-00001.parquet ({len(train_dataset)} samples)") + print(f" └── test-00000-of-00001.parquet ({len(test_dataset)} samples)") + +except Exception as e: + print(f"Error loading dataset {args.target}: {e}") + import traceback + + traceback.print_exc() diff --git a/tutorial/example_feedback_tracing/README.md b/tutorial/example_feedback_tracing/README.md new file mode 100644 index 00000000..8ed775c7 --- /dev/null +++ b/tutorial/example_feedback_tracing/README.md @@ -0,0 +1,46 @@ +# Training a New Agent from Tracing Logs + +AgentJet allows you to recycle the chat logs generated during an Agent’s execution and continuously improve the Agent through iterative training. +This document demonstrates how to train an Agent using tracing log feedback. + +## 1. Preparing the Data + +To use tracing logs for training, you must already have an Agent built with **agentscope** running in **agentscope-studio** for some time. + +In this example, we implement a math-problem-solving agent in `agent_deployed.py`. +To demonstrate the workflow, we will first simulate the data-collection process. + +1. Install [agentscope-studio](https://github.com/agentscope-ai/agentscope-studio). +2. Start agentscope-studio with the default port settings. +3. Run `agent_deployed.py` and simulate user–agent conversations. + +After several rounds of interaction, studio will store the tracing logs in +`~/AgentScope-Studio/database.sqlite`, containing all recorded dialogues between the user and the agent. + +> **AgentScope & Studio Version Compatibility** +> +> It is recommended to use matched versions: +> +> * AgentScope (v1.0.7) +> * Studio (23eb7c0b1185486d1baca36aea0ce8b85ea9de48) + +## 2. Starting Trace-Feedback Training + +Once you have the tracing log (`database.sqlite`), you can use the trace-feedback training module to train a new Agent. + +1. Set the `task_reader` parameter to `tracing` in the configuration file to enable trace-feedback mode. +2. Configure the `tracing` section with the database path and filtering options. +3. Configure other training parameters and Rewards as you would in a normal training workflow. + +An example database and configuration file are provided under +`example_feedback_tracing/`. + +When everything is ready, start the training with: + +```bash +ajet --conf tutorial/example_feedback_tracing/example_feedback_tracing.yaml --backbone='trinity' --with-ray +``` + +## 3. Deploying the New Agent + +You can now deploy the newly trained Agent into production, enabling continuous improvement through iterative trace-feedback training. diff --git a/tutorial/example_feedback_tracing/agent_deployed.py b/tutorial/example_feedback_tracing/agent_deployed.py new file mode 100644 index 00000000..05b0bbcd --- /dev/null +++ b/tutorial/example_feedback_tracing/agent_deployed.py @@ -0,0 +1,54 @@ +import os + +import agentscope +from agentscope.agent import ReActAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.memory import InMemoryMemory +from agentscope.message import Msg +from agentscope.model import DashScopeChatModel +from agentscope.tool import Toolkit, execute_python_code + +SYSTEM_PROMPT = """ +You are an agent specialized in solving math problems with tools. +If I give problem, please solve the math problem given to you. +You can write and execute Python code to perform calculation or verify your answer. +You should return your final answer within \\boxed{{}}. +""" + + +def build_agent(): + tool_kit = Toolkit() + tool_kit.register_tool_function(execute_python_code) + + agent = ReActAgent( + name="Qwen", + sys_prompt=SYSTEM_PROMPT, + model=DashScopeChatModel( + model_name="qwen-max", + api_key=os.environ["DASHSCOPE_API_KEY"], + stream=True, + ), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=tool_kit, + print_hint_msg=False, + ) + + return agent + + +async def main(): + # init the tracing module + agentscope.init(studio_url="http://localhost:3000") + + agent = build_agent() + + while True: + inp = input("User: ") + print(await agent.reply(Msg("user", inp, role="user"))) + + +if __name__ == "__main__": + import asyncio + + asyncio.run(main()) diff --git a/tutorial/example_feedback_tracing/database.sqlite b/tutorial/example_feedback_tracing/database.sqlite new file mode 100644 index 00000000..670372e2 Binary files /dev/null and b/tutorial/example_feedback_tracing/database.sqlite differ diff --git a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml new file mode 100644 index 00000000..1cb01333 --- /dev/null +++ b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml @@ -0,0 +1,77 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_feedback_training + task_reader: + type: tracing + feedback_tracing: + base_url: ./tutorial/example_feedback_tracing/database.sqlite + train_output_path: ./tutorial/example_feedback_tracing/tasks.jsonl + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 2048 + filters: + - type: llm_evaluate + enabled: true + params: + custom_rubrics: | + 1. Check the answer and drop the task if it does not answer or answer is wrong. + 2. Consider a response is invalid if it does not wrap the final answer in \boxed{}. + temperature: 0.5 + print_reason: false + max_thread: 16 + - type: deduplication + enabled: true + params: + similarity_threshold: 0.8 + db_path: ./.similarity_db + model: text-embedding-v4 + api_key: null # load from the env + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + user_workflow: tutorial.example_feedback_tracing.train->ExampleTracingFeedbackTrain + temperature: 0.7 + max_env_worker: 256 + num_repeat: 8 + agent_madness_reward: -1.0 + tensor_model_parallel_size: 4 + max_num_seqs: 256 + multi_turn: + max_sample_per_task: 4 + compute_madness_checklist: + - "nonsense" + - "wrong_toolcall" + data: + train_batch_size: 240 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + +# remove swanlab logger +trainer: + logger: + - console + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_feedback_tracing/train.py b/tutorial/example_feedback_tracing/train.py new file mode 100644 index 00000000..858bae6b --- /dev/null +++ b/tutorial/example_feedback_tracing/train.py @@ -0,0 +1,60 @@ +from agentscope.agent import ReActAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.memory import InMemoryMemory +from agentscope.message import Msg +from agentscope.tool import Toolkit, execute_python_code +from loguru import logger +from pydantic import BaseModel, Field + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + +SYSTEM_PROMPT = """ +You are an agent specialized in solving math problems with tools. +If I give problem, please solve the math problem given to you. +You can write and execute Python code to perform calculation or verify your answer. +You should return your final answer within \\boxed{{}}. +""" + + +def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if hasattr(result, "metadata") and isinstance(result.metadata, dict) and "result" in result.metadata: + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + +class FinalResult(BaseModel): + result: str = Field(description="Your solution of the given math problem. Put your final answer in boxed format, e.g., \\boxed{42}") + + +class ExampleTracingFeedbackTrain(Workflow): + name: str = "tracing_feedback_train" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + query = workflow_task.task.main_query + + tool_kit = Toolkit() + tool_kit.register_tool_function(execute_python_code) + + agent = ReActAgent( + name="Qwen", + sys_prompt=SYSTEM_PROMPT, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + memory=InMemoryMemory(), + toolkit=tool_kit, + print_hint_msg=False, + ) + + msg = Msg("user", query, role="user") + result = await agent.reply(msg, structured_model=FinalResult) + final_answer = extract_final_answer(result) + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_frozenlake/frozenlake.md b/tutorial/example_frozenlake/frozenlake.md new file mode 100644 index 00000000..8ddbabc6 --- /dev/null +++ b/tutorial/example_frozenlake/frozenlake.md @@ -0,0 +1,16 @@ +# Frozen Lake + +This example shows the usage of GRPO on the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task. + + +## Data and Environment Preparation + +After setting up the basic environment following the [installation guidance](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html), you need to install the additional dependencies by running the following command: + +```bash +pip install gymnasium[toy_text] +``` + +## begin training + +python launcher.py --conf tutorial/example_frozenlake/frozenlake.yaml --backbone=verl diff --git a/tutorial/example_frozenlake/frozenlake.py b/tutorial/example_frozenlake/frozenlake.py new file mode 100644 index 00000000..2bdd3bde --- /dev/null +++ b/tutorial/example_frozenlake/frozenlake.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- +""" +This file defines a multi-step workflow for the FrozenLake environment. +Modified from https://github.com/rllm-org/rllm/blob/main/rllm/environments/frozenlake/frozenlake.py +""" + +from __future__ import annotations + +import copy +import re +import traceback +from typing import Dict, Optional, Tuple + +import numpy as np +from agentscope.agent import ReActAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.model import DashScopeChatModel +from agentscope.message import Msg +from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv as GymFrozenLakeEnv +from loguru import logger + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + +SYSTEM_PROMPT = """You are a helpful assistant. You are walking on a frozen lake. + +FrozenLake Quick Guide +Goal: Reach the goal (G). Player (P) and Goal (G) must overlap. + +Symbols: +_ Frozen | O Hole | G Goal | P Player + +Rules: +1. Avoid falling into holes (O). +2. Frozen tiles are slippery, you may move perpendicular to your intended direction. + +Valid Action (separated by | ): +Up | Down | Left | Right + +Rewards: +Fall into hole: 0 +Reach goal: +1.0 + +You will be provided the current observation, please decide on the next Action. +You should show your short thought process and then input the final action in ``` ```. +You should only output the NEXT ACTION at each interation in the ``` ```. For example, if you want to move up, you should output ```Up```. +You should plan ahead and need to achieve it in minimum number of steps. +You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it. + +Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right. +""" + + +class FrozenLakeWorkflow(Workflow): + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + config = tuner.config + + self.env_max_steps = config.ajet.rollout.multi_turn.max_steps + self.agent_max_steps = config.ajet.rollout.multi_turn.max_steps + + # Extract task-specific arguments + self.raw_task = workflow_task.task.metadata + + self.size = config.frozen_lake.frozen_lake_size + self.seed = workflow_task.task.metadata["random_number"] + self.p = 0.8 # Probability that a tile is frozen + + # Agent-related state + self.step_count: int = 0 + + # init agent and environment + self.agent = FrozenLakeAgent( + model=tuner.as_agentscope_model(), + max_steps=self.agent_max_steps, + ) + self.env = FrozenLakeEnv( + max_steps=self.env_max_steps, + is_slippery=config.frozen_lake.is_slippery, + size=self.size, + p=self.p, # Probability that a tile is frozen + seed=self.seed, + ) + + return await self.run_frozenlake() + + async def run_frozenlake(self): + self.env.reset(self.raw_task) + terminate_reason = None + observation_str = str(self.env.render()) + rewards = [] + step_count = 0 + done = False + for _ in range(self.agent_max_steps): + step_count += 1 + try: + action = await self.agent.step(current_observation=observation_str) + except Exception: + logger.error(f"Agent failed to produce action due to error:\n{traceback.format_exc()}") + terminate_reason = "agent_error" + break + observation, reward, done, _ = self.env.step(action) + observation_str = str(observation) + rewards.append(reward) + if done: + terminate_reason = "success" + break + + if terminate_reason is None: + terminate_reason = "max_steps_reached" + + final_reward = sum(rewards) + return WorkflowOutput( + reward=final_reward, + metadata={ + "terminate_reason": terminate_reason, + "step_count": step_count, + }, + ) + + +class FrozenLakeAgent: + INVALID_ACTION = "still" + + def __init__(self, model: DashScopeChatModel, max_steps: int = 20): + self.agent = ReActAgent( + name="frozenlake_agent", + sys_prompt=SYSTEM_PROMPT, + model=model, + formatter=DashScopeChatFormatter(), + max_iters=2, + ) + self.agent.set_console_output_enabled(False) + self.current_step = 0 + self.last_action = None + self.last_observation = None + self.max_steps = max_steps + + def get_prompt(self, observation: str) -> str: + prompt = f"Current Observation ({self.current_step}): \n" + observation + "\n" + "You have not achieved the goal, P has not reached G yet. Please give the next action." + if self.current_step > 0 and self.last_action is not None: + if self.last_observation == observation: + prompt += "\nYour last response is invalid. Your position didn't change at all. You may need to recheck your thinking process, action outputted, and the format of response. Remember, you should only output the NEXT ACTION at each interation in the ``` ```. For example, if you want to move up, you should output ```Up```." + + if self.max_steps is not None and self.max_steps - self.current_step > 0: + prompt += f"\nThe maximum number of steps remaining is {self.max_steps - self.current_step}." + + return prompt + + def get_action(self, msg: Msg) -> str: + logger.info(f"Agent response: {msg.content}") + response: str = msg.content if isinstance(msg.content, str) else msg.content[0].get("text") + action = self.INVALID_ACTION + + matches = re.findall(r"```(.*?)```", response, re.DOTALL) + + if matches: + last_match_content = matches[-1].strip() + action = last_match_content.lower() + if action not in ["up", "down", "left", "right"]: + action = self.INVALID_ACTION + + return action + + async def step(self, current_observation: str) -> str: + prompt = self.get_prompt(current_observation) + msg = await self.agent.reply(Msg("user", prompt, role="user")) + action = self.get_action(msg) + self.last_observation = current_observation + self.last_action = action + self.current_step += 1 + return action + + +class FrozenLakeEnv(GymFrozenLakeEnv): + # Map gym state in integer + MAP_LOOKUP = { + b"P": 0, + b"F": 1, + b"H": 2, + b"G": 3, + } + + # Define rules to transform to rendered text observation of the environment + GRID_LOOKUP = { + 0: " P \t", # player + 1: " _ \t", # frozen + 2: " O \t", # hole + 3: " G \t", # goal + 4: " X \t", # player fall into hole + 5: " √ \t", # player on goal + } + + ACTION_LOOKUP = { + "still": 0, + "left": 1, + "down": 2, + "right": 3, + "up": 4, + } + + INVALID_ACTION = 0 + PENALTY_FOR_INVALID = -1 + + def __init__( + self, + max_steps: int = 8, + is_slippery: bool = False, + size: int = 8, + p: float = 0.8, + seed: int = 42, + ): + self.max_steps = max_steps or 8 + self.is_slippery = is_slippery + self.size = size + self.p = p + self.seed = seed + try: + import gymnasium as gym + from gymnasium.envs.toy_text.frozen_lake import ( + FrozenLakeEnv as GymFrozenLakeEnv, + ) + except ImportError as e: + error_message = f"Gymnasium is not installed. Please install gymnasium first before " f"running the frozen_lake workflow. Error: {str(e)}" + logger.error(error_message) + raise ImportError(error_message) + + random_map, goal_position = generate_random_map(size=self.size, p=self.p, seed=self.seed, max_steps=self.max_steps) + + self.goal_position = goal_position + + GymFrozenLakeEnv.__init__(self, desc=random_map[:], is_slippery=self.is_slippery) + self.action_space = gym.spaces.Discrete(4, start=1) + + self.map_kwargs = { + "size": size, + "p": p, + } + self.env_kwargs = { + "is_slippery": is_slippery, + "desc": None, + "seed": seed, + } + + self.action_map = { + 1: 0, # left + 2: 1, # down + 3: 2, # right + 4: 3, # up + } + + def _get_player_position(self) -> Tuple[int, int]: + return (self.s // self.ncol, self.s % self.ncol) # (row, col) + + def step(self, action: str) -> Tuple[str, float, bool, Dict]: + """Execute a step in the environment. + + Maps custom action to gymnasium FrozenLakeEnv action and takes the step. + Checks if the action is effective (whether player moves in the env). + + Args: + action: The action to take. + + Returns: + Tuple of (observation, reward, done, info). + """ + if self.success(): + return self.render(), 1, True, {"action_is_effective": False} + + action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0) + + if not action_id: + action_id = self.INVALID_ACTION + + if action_id == self.INVALID_ACTION or action_id not in self.action_map: + return self.render(), 0, False, {"action_is_effective": False} + + prev_player_position = int(self.s) + + player_pos, reward, done, _, _ = GymFrozenLakeEnv.step(self, self.action_map[action_id]) + + obs = self.render() + return obs, reward, done, {"action_is_effective": prev_player_position != int(player_pos)} + + def render(self, mode="tiny_rgb_array"): + """Render the environment. + + Args: + mode: Rendering mode. Options: "tiny_rgb_array", "list", "state", "rgb_array", "ansi". + + Returns: + Rendered observation based on the mode. + """ + assert mode in ["tiny_rgb_array", "list", "state", "rgb_array", "ansi"] + if mode in ["rgb_array", "ansi"]: + prev_render_mode = self.render_mode + self.render_mode = mode + obs = GymFrozenLakeEnv.render(self) + self.render_mode = prev_render_mode + return obs + room_state = copy.deepcopy(self.desc) + + # replace the position of start 'S' with 'F' + position_S = np.where(room_state == b"S") + room_state[position_S] = b"F" + + # replace the position of the player with 'P' + position_P = self._get_player_position() + room_state[position_P] = b"P" + + if mode == "state": + # transform 'S', 'F', 'H', 'G' to numpy integer array + room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state) + # add player in hole or player on goal + if self.desc[position_P] == b"H": + room_state[position_P] = 4 + elif self.desc[position_P] == b"G": + room_state[position_P] = 5 + return room_state + + room_state = self.render(mode="state").tolist() + + if mode == "list": + + def lookup(cell): + return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip() + + return [" ".join(lookup(cell) for cell in row) for row in room_state] + + if mode == "tiny_rgb_array": + + def lookup(cell): + return self.GRID_LOOKUP.get(cell, "?") + + result = "\n".join("".join(lookup(cell) for cell in row) for row in room_state) + return result + + def reset(self, task: Optional[Dict] = None): + task = task or {} + self.__init__( # type: ignore [misc] + size=task.get("size", self.map_kwargs["size"]), + p=task.get("p", self.map_kwargs["p"]), + seed=task.get("seed", self.env_kwargs["seed"]), + is_slippery=task.get("is_slippery", self.env_kwargs["is_slippery"]), + ) + GymFrozenLakeEnv.reset(self, seed=self.seed) + return self.render(mode="tiny_rgb_array"), {} + + def finished(self) -> bool: + player_pos = self._get_player_position() + return self.desc[player_pos] in b"GH" # type: ignore [index,operator] + + def success(self): + """ + Check if the agent has reached the goal (G). + """ + player_pos = self._get_player_position() + return self.desc[player_pos] in b"G" + + +def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool: + """DFS to check that it's a valid path. + + Args: + board: The board representation as a list of lists. + max_size: Maximum size of the board. + max_steps: Maximum number of steps allowed. + + Returns: + True if there's a valid path from start to goal within max_steps, False otherwise. + """ + frontier, discovered = [], set() + # find the start point + start_r, start_c = np.where(np.array(board) == "S") + frontier.append((start_r[0], start_c[0], 0)) # row, col steps + # dfs to check if there is a path from start to goal + while frontier: + r, c, steps = frontier.pop() + if steps > max_steps: + continue + + if (r, c) not in discovered: + discovered.add((r, c)) + directions = [(1, 0), (0, 1), (-1, 0), (0, -1)] + for x, y in directions: + r_new = r + x + c_new = c + y + if r_new < 0 or r_new >= max_size or c_new < 0 or c_new >= max_size: + continue + if board[r_new][c_new] == "G": + return True + if board[r_new][c_new] != "H": + frontier.append((r_new, c_new, steps + 1)) + return False + + +def generate_random_map(size: int = 8, p: float = 0.8, seed: int = 0, max_steps: int = 5) -> Tuple[list[str], Tuple[int, int]]: + """Generates a random valid map (one that has a path from start to goal). + + Args: + size: Size of each side of the grid. + p: Probability that a tile is frozen. + seed: Seed to ensure the generation of reproducible maps. + max_steps: Maximum number of steps allowed. + + Returns: + A tuple containing a random valid map and the goal position (row, col). + """ + valid = False + board: list[list[str]] = [] # initialize to make pyright happy + + try: + from gymnasium.utils import seeding + + np_random, _ = seeding.np_random(seed) + except ImportError: + raise ImportError("Gymnasium is not installed. Please install gymnasium first before running the frozen_lake workflow.") + + # generate random start and end points + while not valid: + p = min(1, p) + board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p]).tolist() + + while True: + start_r = int(np_random.integers(0, size)) + start_c = int(np_random.integers(0, size)) + goal_r = int(np_random.integers(0, size)) + goal_c = int(np_random.integers(0, size)) + + # Ensure start and goal are different positions + if (start_r, start_c) != (goal_r, goal_c): + break + + board[start_r][start_c] = "S" + board[goal_r][goal_c] = "G" + + valid = is_valid(board, size, max_steps) + return ["".join(x) for x in board], (goal_r, goal_c) + + +def get_goal_position(random_map: np.ndarray) -> Optional[Tuple[int, int]]: + """Get the goal position from a random map. + + Args: + random_map: The map as a numpy array. + + Returns: + Tuple of (row, col) if goal found, None otherwise. + """ + positions = np.argwhere(random_map == b"G") + if positions.size == 0: + return None # G not found + return tuple(positions[0]) # returns (row, col) diff --git a/tutorial/example_frozenlake/frozenlake_easy.yaml b/tutorial/example_frozenlake/frozenlake_easy.yaml new file mode 100644 index 00000000..d6ba032a --- /dev/null +++ b/tutorial/example_frozenlake/frozenlake_easy.yaml @@ -0,0 +1,91 @@ +# ------------------ main config ------------------ +ajet: + project_name: example_frozenlake + task_reader: + type: random_dummy # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` or `random_dummy` + + task_judge: + # ✨✨✨✨ select evaluation function + judge_protocol: null + + model: + # ✨✨✨✨ select model to be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct + + + rollout: + # the path to the workflow class + user_workflow: tutorial.example_frozenlake.frozenlake->FrozenLakeWorkflow # ✨✨✨✨ select agent + force_disable_toolcalls: True + + temperature: 0.9 + + max_env_worker: 512 + + num_repeat: 4 + + agent_madness_reward: 0.0 + + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 4 + + mode: async + + max_num_seqs: 40 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: [] + + max_response_length_in_one_turn: 512 + + max_model_len: 22000 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 20 + # limit the maximum steps for each task + max_steps: 20 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 128 + max_prompt_length: 4000 + max_response_length: 18000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + total_training_steps: 25 + nnodes: 1 + n_gpus_per_node: 8 + logger: swanlab + + +frozen_lake: + frozen_lake_size: 4 + is_slippery: False + + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ do not edit ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_frozenlake/frozenlake_hard.yaml b/tutorial/example_frozenlake/frozenlake_hard.yaml new file mode 100644 index 00000000..beb23df1 --- /dev/null +++ b/tutorial/example_frozenlake/frozenlake_hard.yaml @@ -0,0 +1,89 @@ +# ------------------ main config ------------------ +ajet: + project_name: example_frozenlake_7b + task_reader: + type: random_dummy # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` or `random_dummy` + + task_judge: + # ✨✨✨✨ select evaluation function + judge_protocol: null + + model: + # ✨✨✨✨ select model to be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + + rollout: + # the path to the workflow class + user_workflow: tutorial.example_frozenlake.frozenlake->FrozenLakeWorkflow # ✨✨✨✨ select agent + force_disable_toolcalls: True + + temperature: 0.9 + + max_env_worker: 512 + + num_repeat: 4 + + agent_madness_reward: 0.0 + + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 4 + + mode: async + + max_num_seqs: 40 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: [] + + max_response_length_in_one_turn: 512 + + max_model_len: 22000 + + multi_turn: + # how many samples should be collected for each task run + max_sample_per_task: 20 + # limit the maximum steps for each task + max_steps: 20 + # the expected steps for each task, used to calculate the training batch size for trinity + expected_steps: 1 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 128 + max_prompt_length: 4000 + max_response_length: 18000 + + trainer_common: + save_freq: 99999 + test_freq: 99999 + total_epochs: 99999 + total_training_steps: 25 + nnodes: 1 + n_gpus_per_node: 8 + logger: swanlab + +frozen_lake: + frozen_lake_size: 6 + is_slippery: True + + + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ do not edit ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_learn2ask/data_preprocess/download_processed.sh b/tutorial/example_learn2ask/data_preprocess/download_processed.sh new file mode 100755 index 00000000..3359580b --- /dev/null +++ b/tutorial/example_learn2ask/data_preprocess/download_processed.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" +URL="https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/astuner_archive/RealMedConv_processed.zip" +TMP_DIR="/tmp/learn2ask" + +TARGET_DIR="${1:-${REPO_ROOT}/data/realmedconv}" +if [[ "${TARGET_DIR}" != /* ]]; then + TARGET_DIR="${REPO_ROOT}/${TARGET_DIR}" +fi + +cleanup() { + rm -rf "${TMP_DIR}" +} +trap cleanup EXIT + +rm -rf "${TMP_DIR}" +mkdir -p "${TMP_DIR}" "${TARGET_DIR}" + +wget -O "${TMP_DIR}/RealMedConv_processed.zip" "${URL}" +unzip -o "${TMP_DIR}/RealMedConv_processed.zip" -d "${TMP_DIR}" + +mv "${TMP_DIR}/test.jsonl" "${TARGET_DIR}/test.jsonl" +mv "${TMP_DIR}/train.jsonl" "${TARGET_DIR}/train.jsonl" + +echo "Saved processed dataset to ${TARGET_DIR}" diff --git a/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py b/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py new file mode 100644 index 00000000..7c36f416 --- /dev/null +++ b/tutorial/example_learn2ask/data_preprocess/llm_info_extraction.py @@ -0,0 +1,146 @@ +import os + +import openai +import torch +import transformers + +tokenizer = None +llm = None + + +def LLM_info_extraction(remaining_chat, model_call_mode, **kwargs): + """ + Extract information from remaining_chat using LLM. + + Args: + remaining_chat (str): The chat content to process + model_call_mode (str): Either "online_api" or "local_vllm" + **kwargs: Additional parameters for API calls + + Returns: + str: Response text from LLM or error information + """ + + # Create messages format with system and user roles + system_message = """ + # Task: + You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field. + # Requirements: + - Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output. + - Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown"). + - Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field. + - Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, ensuring any necessary escape characters within the strings. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]". + - Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology. + # Example output format: + ['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day'] + """ + user_message = remaining_chat + + messages = [ + {"role": "system", "content": system_message}, + {"role": "user", "content": "```\n" + user_message + "\n```\n"}, + ] + + try: + if model_call_mode == "online_api": + # OpenAI-style API call + return _call_online_api(messages, **kwargs) + elif model_call_mode == "local_vllm": + # Local vLLM call + return _call_local_vllm(messages, **kwargs) + else: + return f"Error: Invalid model_call_mode '{model_call_mode}'. Must be 'online_api' or 'local_vllm'." + except Exception as e: + return f"Error occurred: {str(e)}" + + +def _call_online_api(messages, **kwargs): + """Handle OpenAI-style API calls""" + # Extract API parameters from kwargs or use defaults + api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY")) + api_base = kwargs.get("api_base", "https://dashscope.aliyuncs.com/compatible-mode/v1") + model = kwargs.get("model", "qwen-max") + temperature = kwargs.get("temperature", 0.7) + max_tokens = kwargs.get("max_tokens", 500) + + client = openai.OpenAI(api_key=api_key, base_url=api_base) + response = client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens) + + return response.choices[0].message.content + + +def _call_local_vllm(messages, **kwargs): + """Handle local vLLM calls""" + try: + from vllm import LLM, SamplingParams + + model_path = kwargs.get("model_path") + if not model_path: + return "Error: model_path is required for local vLLM inference" + + temperature = kwargs.get("temperature", 0.7) + max_tokens = kwargs.get("max_tokens", 512) + top_p = kwargs.get("top_p", 0.9) + repetition_penalty = kwargs.get("repetition_penalty", 1.1) + + # GPU/CUDA related parameters for vLLM + tensor_parallel_size = kwargs.get("tensor_parallel_size", torch.cuda.device_count()) + gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9) + enforce_eager = kwargs.get("enforce_eager", False) + dtype = kwargs.get("dtype", "auto") + max_model_len = kwargs.get("max_model_len", 4096) + + # Initialize the LLM with the provided model path and GPU parameters + global llm, tokenizer + if llm is None: + llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + dtype=dtype, + max_model_len=max_model_len, + ) + + sampling_params = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + repetition_penalty=repetition_penalty, + ) + + # Convert messages to a single prompt string + if tokenizer is None: + tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) + prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + outputs = llm.generate([prompt], sampling_params) + + return outputs[0].outputs[0].text + + except ImportError: + return "Error: vLLM library not installed. Please install it with 'pip install vllm'" + except Exception as e: + return f"Error in local vLLM inference: {str(e)}" + + +def parse_llm_output(output_str): + """ + Convert the LLM info extraction output string to a list of strings. + + Args: + output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']" + + Returns: + list: List of strings if successful, error message string if failed + """ + import ast + + try: + result = ast.literal_eval(output_str) + if not isinstance(result, list): + return f"Error: Expected a list, got {type(result)}" + + return result + except Exception as e: + return f"Error parsing output: [{repr(output_str)}] error = {str(e)}" diff --git a/tutorial/example_learn2ask/data_preprocess/message_splitter.py b/tutorial/example_learn2ask/data_preprocess/message_splitter.py new file mode 100644 index 00000000..06362b05 --- /dev/null +++ b/tutorial/example_learn2ask/data_preprocess/message_splitter.py @@ -0,0 +1,100 @@ +import json + + +def split_single_message_list(messages): + """ + Split a single message list into multiple rounds. + + Args: + messages (list): List of message dictionaries with 'role' and 'content' keys + + Returns: + list: List of rounds, where each round contains messages and remaining chat + """ + rounds = [] + round_number = 1 + i = 0 + + while i < len(messages): + # Collect messages for this round + round_messages = [] + + # Add messages until we reach a user message + while i < len(messages) and messages[i].get("role") != "user": + round_messages.append(messages[i]) + i += 1 + + # Add user message(s) - if there are consecutive user messages, + # include all of them in this round + while i < len(messages) and messages[i].get("role") == "user": + round_messages.append(messages[i]) + i += 1 + + # The remaining messages (if any) form the remaining_chat + remaining_messages = messages[i:] + round_entry = {"round_number": round_number, "messages": round_messages} + + # Add remaining chat if there are remaining messages + if remaining_messages: + remaining_chat_parts = [] + for msg in remaining_messages: + role = msg.get("role", "") + content = msg.get("content", "") + remaining_chat_parts.append(f"{role}: {content}") + round_entry["remaining_chat"] = "\n".join(remaining_chat_parts) + else: + round_entry["remaining_chat"] = "" + + rounds.append(round_entry) + round_number += 1 + + return rounds + + +def split_session_to_json_lines(session): + """ + Split a session dictionary into multiple rounds and convert to JSON lines. + + Args: + session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys + - session_id (str): Session identifier + - diagn (str): Diagnosis information + - messages (list): List of message dictionaries with 'role' and 'content' keys + + Returns: + list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat + """ + rounds = split_single_message_list(session["messages"]) + + json_lines = [] + for round_data in rounds: + round_entry = { + "cid": f"{session['session_id']}_{round_data['round_number']}", + "session_id": session["session_id"], + "diagn": session["diagn"], + "messages": round_data["messages"], + "remaining_chat": round_data["remaining_chat"], + } + + json_lines.append(json.dumps(round_entry, ensure_ascii=False)) + + return json_lines + + +# Example usage: +if __name__ == "__main__": + # Example of splitting a single message list + example_messages = [ + {"role": "assistant", "content": "Hello, how can I help you today?"}, + {"role": "user", "content": "I've been having headaches lately."}, + {"role": "assistant", "content": "How long have you been experiencing these headaches?"}, + {"role": "user", "content": "For about a week now."}, + {"role": "assistant", "content": "I see. Have you taken any medication for them?"}, + {"role": "user", "content": "Yes, I've tried some over-the-counter pain relievers."}, + ] + + example_session = {"session_id": "session_1", "diagn": "migraine", "messages": example_messages} + json_lines = split_session_to_json_lines(example_session) + print("JSON lines output:") + for i, line in enumerate(json_lines): + print(f"Line {i + 1}: {line}") diff --git a/tutorial/example_learn2ask/data_preprocess/run_process.sh b/tutorial/example_learn2ask/data_preprocess/run_process.sh new file mode 100755 index 00000000..774f8bca --- /dev/null +++ b/tutorial/example_learn2ask/data_preprocess/run_process.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# export DASHSCOPE_API_KEY=your_api_key + +# python tutorial/example_learn2ask/data_preprocess/step1.py --input_file data/realmedconv/train_original.jsonl --output_file data/realmedconv/train_processed.jsonl +# python tutorial/example_learn2ask/data_preprocess/step2.py --input_file data/realmedconv/train_processed.jsonl --output_file data/realmedconv/train.jsonl + +# python tutorial/example_learn2ask/data_preprocess/step1.py --input_file data/realmedconv/test_original.jsonl --output_file data/realmedconv/test_processed.jsonl +# python tutorial/example_learn2ask/data_preprocess/step2.py --input_file data/realmedconv/test_processed.jsonl --output_file data/realmedconv/test.jsonl + + +set -euo pipefail + +DATA_DIR="${1:-}" +if [[ -z "${DATA_DIR}" ]]; then + echo "Usage: $0 " >&2 + exit 2 +fi + +if [[ ! -d "${DATA_DIR}" ]]; then + echo "Error: data_dir is not a directory: ${DATA_DIR}" >&2 + exit 2 +fi + +if [[ -z "${DASHSCOPE_API_KEY:-}" ]]; then + echo "Error: DASHSCOPE_API_KEY is not set. Please run: export DASHSCOPE_API_KEY=your_api_key" >&2 + exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)" + +TRAIN_ORIG="${DATA_DIR%/}/train_origin.jsonl" +TRAIN_PROC="${DATA_DIR%/}/train_processed.jsonl" +TRAIN_OUT="${DATA_DIR%/}/train.jsonl" + +TEST_ORIG="${DATA_DIR%/}/test_origin.jsonl" +TEST_PROC="${DATA_DIR%/}/test_processed.jsonl" +TEST_OUT="${DATA_DIR%/}/test.jsonl" + +if [[ ! -f "${TRAIN_ORIG}" ]]; then + echo "Error: missing file: ${TRAIN_ORIG}" >&2 + exit 2 +fi + +if [[ ! -f "${TEST_ORIG}" ]]; then + echo "Error: missing file: ${TEST_ORIG}" >&2 + exit 2 +fi + +python "${REPO_ROOT}/tutorial/example_learn2ask/data_preprocess/step1.py" --input_file "${TRAIN_ORIG}" --output_file "${TRAIN_PROC}" +python "${REPO_ROOT}/tutorial/example_learn2ask/data_preprocess/step2.py" --input_file "${TRAIN_PROC}" --output_file "${TRAIN_OUT}" + +python "${REPO_ROOT}/tutorial/example_learn2ask/data_preprocess/step1.py" --input_file "${TEST_ORIG}" --output_file "${TEST_PROC}" +python "${REPO_ROOT}/tutorial/example_learn2ask/data_preprocess/step2.py" --input_file "${TEST_PROC}" --output_file "${TEST_OUT}" + +echo "Done. Generated:" +echo "- ${TRAIN_OUT}" +echo "- ${TEST_OUT}" diff --git a/tutorial/example_learn2ask/data_preprocess/step1.py b/tutorial/example_learn2ask/data_preprocess/step1.py new file mode 100644 index 00000000..610324fa --- /dev/null +++ b/tutorial/example_learn2ask/data_preprocess/step1.py @@ -0,0 +1,197 @@ +import argparse +import json +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +from llm_info_extraction import LLM_info_extraction, parse_llm_output +from message_splitter import split_session_to_json_lines + + +def process_jsonl_file(input_file, output_file, model_call_mode="online_api", max_retries=3, max_workers=16, **kwargs): + """ + Process all sessions in a JSONL file and save results to output file using multi-threading. + Supports resuming from previous progress if interrupted. + + Args: + input_file (str): Path to input JSONL file + output_file (str): Path to output JSONL file + model_call_mode (str): Either "online_api" or "local_vllm" + max_retries (int): Maximum number of retries for LLM calls + max_workers (int): Maximum number of threads for parallel processing + **kwargs: Additional parameters for API calls + + Returns: + str: Success message or error information + """ + progress_file = output_file + ".progress" + + def load_progress(): + """Load progress from progress file. Returns set of completed line numbers.""" + if os.path.exists(progress_file): + with open(progress_file, "r", encoding="utf-8") as f: + return set(int(line.strip()) for line in f if line.strip()) + return set() + + def process_single_session(args): + """Worker function to process a single session.""" + line_num, line = args + if not line.strip(): + return line_num, None, None + try: + session = json.loads(line) + print(f"Processing session {session.get('session_id', 'unknown')} (line {line_num})...") + processed_lines = process_session(session, model_call_mode, max_retries, **kwargs) + return line_num, processed_lines, None + except json.JSONDecodeError as e: + return line_num, None, f"Warning: Skipping invalid JSON at line {line_num}: {e}" + except Exception as e: + return line_num, None, f"Warning: Error processing session at line {line_num}: {e}" + + try: + # Load previous progress + completed_lines = load_progress() + if completed_lines: + print(f"Resuming from previous progress. {len(completed_lines)} lines already completed.") + + # Read all lines first + with open(input_file, "r", encoding="utf-8") as infile: + all_lines = list(enumerate(infile, 1)) + + total_lines = len(all_lines) + # Filter out already completed lines + lines_to_process = [(num, line) for num, line in all_lines if num not in completed_lines] + + if not lines_to_process: + print("All lines already processed.") + # Clean up progress file + if os.path.exists(progress_file): + os.remove(progress_file) + return f"All lines already processed. Results in {output_file}" + + print(f"Processing {len(lines_to_process)} remaining lines out of {total_lines} total.") + + # State for ordered writing + results_buffer = {} # line_num -> processed_lines + next_line_to_write = min(num for num, _ in lines_to_process) + write_lock = threading.Lock() + progress_lock = threading.Lock() + + # Open output file in append mode if resuming, otherwise write mode + file_mode = "a" if completed_lines else "w" + outfile = open(output_file, file_mode, encoding="utf-8") + progress_out = open(progress_file, "a", encoding="utf-8") + + def flush_buffer(): + """Write all consecutive completed results from buffer to file.""" + nonlocal next_line_to_write + while next_line_to_write in results_buffer: + processed_lines = results_buffer.pop(next_line_to_write) + if processed_lines: + for processed_line in processed_lines: + outfile.write(processed_line + "\n") + outfile.flush() + # Save progress + with progress_lock: + progress_out.write(f"{next_line_to_write}\n") + progress_out.flush() + next_line_to_write += 1 + # Skip lines that were already completed or empty + while next_line_to_write <= total_lines and next_line_to_write not in dict(lines_to_process): + next_line_to_write += 1 + + try: + # Process sessions in parallel using ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(process_single_session, item): item[0] for item in lines_to_process} + + for future in as_completed(futures): + line_num, processed_lines, error = future.result() + if error: + print(error) + + with write_lock: + results_buffer[line_num] = processed_lines + flush_buffer() + finally: + outfile.close() + progress_out.close() + + # Clean up progress file on successful completion + if os.path.exists(progress_file): + os.remove(progress_file) + + return f"Successfully processed. Results saved to {output_file}" + + except Exception as e: + return f"Error processing JSONL file: {str(e)}" + + +def process_session(session, model_call_mode="online_api", max_retries=3, **kwargs): + """ + Pipeline function that splits messages into rounds and extracts info from each round's remaining chat. + + Args: + session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys + model_call_mode (str): Either "online_api" or "local_vllm" + max_retries (int): Maximum number of retries for LLM calls + **kwargs: Additional parameters for API calls + + Returns: + list: List of JSON strings with added "info_set" key, or error information + """ + # Step 1: Split messages into JSON lines + json_lines = split_session_to_json_lines(session) + + # Step 2: Process each JSON line with LLM info extraction + processed_lines = [] + + for line in json_lines: + data = json.loads(line) + remaining_chat = data.get("remaining_chat", "") + + # Retry loop for LLM calls + info_set = None + for attempt in range(max_retries): + try: + # Call LLM info extraction (using mock function for testing) + llm_response = LLM_info_extraction(remaining_chat, model_call_mode, **kwargs) + + info_set = parse_llm_output(llm_response) + + if isinstance(info_set, list): + break + else: + # If parsing failed, this is an error message + print(f"Attempt {attempt + 1} failed: {info_set}") + if attempt < max_retries - 1: + time.sleep(24) + except Exception as e: + print(f"Attempt {attempt + 1} failed with exception: {str(e)}") + if attempt < max_retries - 1: + time.sleep(24) # Shorter wait for testing + + if info_set is None: + raise Exception(f"failed to generate {session}") + data["info_set"] = info_set + processed_lines.append(json.dumps(data, ensure_ascii=False)) + + return processed_lines + + +# Example usage: +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_file", type=str, default="examples/learn_to_ask/data_raw/train_origin.jsonl") + parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl") + parser.add_argument("--model_call_mode", type=str, choices=["online_api", "local_vllm"], default="online_api") + args = parser.parse_args() + print( + process_jsonl_file( + input_file=args.input_file, + output_file=args.output_file, + model_call_mode=args.model_call_mode, + # Additional parameters for API calls + ) + ) diff --git a/tutorial/example_learn2ask/data_preprocess/step2.py b/tutorial/example_learn2ask/data_preprocess/step2.py new file mode 100644 index 00000000..45cda5c2 --- /dev/null +++ b/tutorial/example_learn2ask/data_preprocess/step2.py @@ -0,0 +1,55 @@ +import argparse +import json + + +def process_message(json_obj): + info_set = json_obj.get("info_set") + info_set_str = ", ".join(info_set) if isinstance(info_set, list) else "" + if "user: " not in json_obj["remaining_chat"]: + decision_str = "stop" + else: + decision_str = "continue" + if not info_set_str and decision_str == "continue": + if_keep = False + else: + if_keep = True + return if_keep, info_set_str, decision_str + + +def main(input_file_path, output_file_path): + with open(input_file_path, "r", encoding="utf-8") as infile, open(output_file_path, "w", encoding="utf-8") as outfile: + print("data processing started...") + for line in infile: + data = json.loads(line.strip()) + if_keep, info_set, decision = process_message(data) + if not if_keep: + continue + + new_item = { + "main_query": "[no query]", + "init_messages": data["messages"], + "task_id": data["cid"], + "env_type": "[no env]", + "metadata": { + "session_id": data["session_id"], + "diagn": data["diagn"], + "decision_truth": decision, + "info_truth": info_set, + }, + } + outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n") + print("job done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # The file generated by 1_info_extract_pipeline.py + parser.add_argument("--input_file", type=str, default="examples/learn_to_ask/data_raw/train_processed.jsonl") + + # The final file for training or testing + parser.add_argument("--output_file", type=str, default="examples/learn_to_ask/data/train.jsonl") + + args = parser.parse_args() + + main(args.input_file, args.output_file) diff --git a/tutorial/example_learn2ask/learn2ask.md b/tutorial/example_learn2ask/learn2ask.md new file mode 100644 index 00000000..811d37f9 --- /dev/null +++ b/tutorial/example_learn2ask/learn2ask.md @@ -0,0 +1,102 @@ +# Learning to Ask + +Traditional LLMs primarily function by generating a direct answer or completing text based on a given prompt or question, whereas the core of the **Learning to Ask** task is to train an agent to learn how to ask questions that elicit the most valuable information and best advance the task. + +![](https://img.alicdn.com/imgextra/i4/O1CN01m9WJCM1WJL1aJCSaS_!!6000000002767-2-tps-1024-559.png) + +This document demonstrates how to prepare data, build an agent and workflow, set up rewards, and ultimately train a 7B agent for this task. + +## 1. Prepare Dataset +Download [RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv) dataset from HuggingFace, and put the files in `data/realmedconv`. + +- [Option 1] Run the following command to preprocess the dataset: + + ```bash + export DASHSCOPE_API_KEY=your_api_key + + cd tutorial/example_learn2ask/data_preprocess + ./run_process.sh data/realmedconv + ``` + +- [Option 2] download the processed dataset from [here](TODO) and put the files in `data/realmedconv`. + ```bash + bash tutorial/example_learn2ask/data_preprocess/download_processed.sh + ``` + +You now will get two datasets: +- `train.jsonl`: the train split +- `test.jsonl`: the test split + +Next, we will prepare a workflow to train an agent with these data. + +## 2. Prepare Workflow +Refer to `tutorial/example_learn2ask/learn2ask.py` for details of workflow. + +In this workflow, we implement: +- `ExampleLearn2Ask`: the workflow and agent +- `reward_fn`: the reward of learn2ask task +- `llm_reward`: a llm-as-a-judge used in `reward_fn` + +The agent reads the messages between a doctor and a patient, and tries predicting the next question to collect information and diagnose the patient. The question is judged by a LLM to check if it is valuable and relevant. + +## 3. Prepare Training Configuration +Next, we need a YAML configuration file to tell AgentJet what data, workflow, reward to use. + +In `tutorial/example_learn2ask/learn2ask.yaml`, we set them respectively: +- Set experiment name. +- Use `dataset_file` reader to read the training and validation data. +- Set the pretrained model. +- Use `ExampleLearn2Ask` as the workflow. +- Set other training parameters. + +You can use this exisiting configuration file straigtly and fill the dataset path in `tutorial/example_learn2ask/learn2ask.yaml`. + +## 3. Start Training + +Start the training with + +```bash +ajet --conf tutorial/example_learn2ask/learn2ask.yaml +``` + +## 4. Results + +### Main Result + +![curve](https://img.alicdn.com/imgextra/i2/O1CN01STeE8G1ricp6jySto_!!6000000005665-2-tps-866-614.png) + +## Analysis + +### Sample 1 + + +**Before training.** +``` +User: Bitten by a midge + +Agent: Has irritation subsided past initial acute swelling phase afterwards unnoticed scratched overly causing secondary bacterial dermatitis chance affecting skin rendering needed topical antibiotic application options among azelaic acid cream!? +``` + +The agent provided a slightly related but completely unreasonable answer and failed to form any effective question to the user. + +**During Training, ~50% step.** +``` +User: Bitten by a midge + +Agent: Have these itchy bumps appeared soon after getting bitten?(A.) Occasionally immediately afterward ; B fkknore ×later throughout day )/) +``` + +The agent provided a relatively effective question, attempting to gather the user's general symptoms and offering two options. + +**After training.** +``` +User: Bitten by a midge + +Agent: Has itching or reddening appeared around this bite site recently without healing despite scratching frequently?(A,Yes indeed worse every evening B,No pain but small swollen spots persist day after incident/) +``` + +The agent's question is more precise and informative, providing two specific and effective options. This is effective in guiding the user to recall symptom details. + +## Next + +To learn more about the task and results on larger models, refer to [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441). diff --git a/tutorial/example_learn2ask/learn2ask.py b/tutorial/example_learn2ask/learn2ask.py new file mode 100644 index 00000000..a88fdc4e --- /dev/null +++ b/tutorial/example_learn2ask/learn2ask.py @@ -0,0 +1,200 @@ +import re +import time +import asyncio +import threading + +from agentscope.message import Msg +from loguru import logger + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from ajet.utils.robust_dashscope import RobustDashScopeChatModel + +system_prompt = """# Task +You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English. + +## Guidelines +- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices. +- Do not repeat any previous question. +- Your response must be a single sentence. +- If enough information has been gathered to make a medication suggestion, output only: +""" + +reward_prompt = """# Task +You are an evaluation assistant. The user will provide a dialogue history between a doctor and a patient. You must analyze the dialogue and evaluate the doctor's last message. + +# Grading Policy +## Format Score +- 1.0: The doctor's last message contains exactly **one question**. +- 0.5: The doctor's last message contains **two questions**. +- 0.0: The doctor's last message contains **three or more questions**. + +## Content Score +Reference Information contains the information that the doctor has not known. + +- 1.0: The question(s) **directly ask about** item in the Reference Information. +- 0.1: The question(s) are a general type of question that could be asked for any symptoms. +- 0.0: The question(s) are **irrelevant** to all items in the Reference Information. + +### You should + +- ONLY if the doctor asks a question that helps to collect information and diagnose the patient, it is a good question. +- A ambiguous question should get 0. + - For example, the doctor asks "How long have you been feeling this way?", but "this way" is not clear in the previous messages. + - For example, the doctor asks "Do you feel bad?". This is a meaningless question that does not provide any useful information. + +# Reference Information + +{} + +# Output Format +Explain your reasoning for the format and content scores clearly and concisely. +Insert only the format score as a float (e.g., 1.0, 0.5, 0.0) +Insert only the content score as a float (e.g., 1.0, 0.5, 0.0) + +> ✅ Important: +> - Output **exactly** the three tags shown above. +> - Do **not** include any additional text, explanation, or formatting outside the tags. +> - Scores must be based **only** on the doctor's **last message** and the provided Reference Information. +> - Ensure clarity and precision in your evaluation reasoning within the `` tag. +""" + + +llm = RobustDashScopeChatModel("qwen-plus", stream=False) + + +async def llm_reward(init_messages: list[Msg], response: str, truth_info: str): + def format_messages(messages: list[Msg]) -> str: + result_str = "" + for msg in messages: + if msg.role == "user": + result_str += f"patient: {msg.content}\n" + if msg.role == "assistant": + result_str += f"doctor: {msg.content}\n" + return result_str + + def parse_tag_string(text: str): + pattern = r"<(\w+)>(.*?)" + matches = re.findall(pattern, text) + result = {} + for tag, value in matches: + result[tag] = value + return result + + history = format_messages([] + init_messages + [Msg("assistant", response, role="assistant")]) + messages = [ + {"role": "system", "content": reward_prompt.format(truth_info)}, + {"role": "user", "content": history}, + ] + + try_count, max_retries = 0, 5 + while try_count <= max_retries: + try: + + async def get_content(): + from agentscope.model import ChatResponse + + response = await llm(messages) + + if isinstance(response, ChatResponse): + res = "".join([x["text"] for x in response.content if "text" in x]) + else: + res = "" + async for chunk in response: + res += "".join([x["text"] for x in chunk.content if "text" in x]) + return res + + content = await get_content() + score_dict = parse_tag_string(content) + return score_dict + except Exception as e: + if try_count > max_retries: + logger.warning("retried too many times, abort task.") + return None + else: + logger.warning(f"error: {e}, response:{response}, retrying...") + time.sleep(2**try_count) + + +async def reward_fn(init_messages: list[Msg], response: str, truth_action: str, truth_info: str): + """ + content_score: R_a, the reward for response quality + action_score: R_s, the reward for decision correctness + format_score: P, the reward for response format + """ + + action_response = "stop" if "" in response else "continue" + if truth_action == action_response: + action_score = 1.0 + if truth_action == "continue": + score_dict = await llm_reward(init_messages, response, truth_info) + if score_dict is not None: + format_score = float(score_dict.get("format_score", 0.0)) + content_score = float(score_dict.get("content_score", 0.0)) + else: + format_score, content_score = 0.0, 0.0 + else: + content_score = 1.0 + format_score = 1.0 if response == "" else 0.0 + else: + action_score, format_score, content_score = 0.0, 0.0, 0.0 + + # treat as self.train_mode == "Ra+Rs", the default setting + final_reward = action_score * (1 + 2 * content_score) + format_score + + return final_reward + + +_reward_semaphore = threading.Semaphore(16) + + +async def reward_fn_with_semaphore(*args, **kwargs): + get_sem_ok = False + while not get_sem_ok: + get_sem_ok = _reward_semaphore.acquire(blocking=False) + if not get_sem_ok: + await asyncio.sleep(1) + + try: + fn_result = await reward_fn(*args, **kwargs) + finally: + _reward_semaphore.release() + + return fn_result + + +class ExampleLearn2Ask(Workflow): + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + + messages = workflow_task.task.init_messages + assert isinstance(messages, list) + truth_action = workflow_task.task.metadata["decision_truth"] or "continue" + truth_info = workflow_task.task.metadata["info_truth"] + + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + toolkit=None, + memory=InMemoryMemory(), + max_iters=1, + ) + self.agent.set_console_output_enabled(False) + msg = [ + # Msg("system", system_prompt, role="system"), + *[Msg(name=x["role"], content=x["content"], role=x["role"]) for x in messages] + ] + result = await self.agent.reply(msg) + if isinstance(result.content, str): + response = result.content + elif isinstance(result.content, list): + response = result.content[0]["text"] # type: ignore + else: + raise NotImplementedError(f"do not know how to handle {type(result.content)}") + reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info) + return WorkflowOutput(reward=reward) diff --git a/tutorial/example_learn2ask/learn2ask.yaml b/tutorial/example_learn2ask/learn2ask.yaml new file mode 100644 index 00000000..acacbce2 --- /dev/null +++ b/tutorial/example_learn2ask/learn2ask.yaml @@ -0,0 +1,68 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_learn2ask_enhancedreward + task_reader: + type: dataset_file + dataset_file: + training: + file_path: your_file_path + validation: + file_path: your_file_path + + model: + path: Qwen/Qwen2.5-7B-Instruct + + rollout: + user_workflow: "tutorial.example_learn2ask.learn2ask->ExampleLearn2Ask" + force_disable_toolcalls: True + temperature: 1.0 + max_env_worker: 64 + num_repeat: 6 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + multi_turn: + max_sample_per_task: 2 + + compute_madness_checklist: + - "nonsense" + agent_madness_reward: 0.0 + + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + n_vllm_engine: 2 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 100 + total_epochs: 100 + logger: swanlab + + +trinity: + synchronizer: + sync_offset: 1 + sync_method: nccl + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_learn2ask/learn2ask_langchain.py b/tutorial/example_learn2ask/learn2ask_langchain.py new file mode 100644 index 00000000..96b2cca1 --- /dev/null +++ b/tutorial/example_learn2ask/learn2ask_langchain.py @@ -0,0 +1,198 @@ +import re +import time +import asyncio +import threading + +from agentscope.message import Msg +from loguru import logger + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from ajet.utils.robust_dashscope import RobustDashScopeChatModel + +system_prompt = """# Task +You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English. + +## Guidelines +- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices. +- Do not repeat any previous question. +- Your response must be a single sentence. +- If enough information has been gathered to make a medication suggestion, output only: +""" + +reward_prompt = """# Task +You are an evaluation assistant. The user will provide a dialogue history between a doctor and a patient. You must analyze the dialogue and evaluate the doctor's last message. + +# Grading Policy +## Format Score +- 1.0: The doctor's last message contains exactly **one question**. +- 0.5: The doctor's last message contains **two questions**. +- 0.0: The doctor's last message contains **three or more questions**. + +## Content Score +Reference Information contains the information that the doctor has not known. + +- 1.0: The question(s) **directly ask about** item in the Reference Information. +- 0.1: The question(s) are a general type of question that could be asked for any symptoms. +- 0.0: The question(s) are **irrelevant** to all items in the Reference Information. + +### You should + +- ONLY if the doctor asks a question that helps to collect information and diagnose the patient, it is a good question. +- A ambiguous question should get 0. + - For example, the doctor asks "How long have you been feeling this way?", but "this way" is not clear in the previous messages. + - For example, the doctor asks "Do you feel bad?". This is a meaningless question that does not provide any useful information. + +# Reference Information + +{} + +# Output Format +Explain your reasoning for the format and content scores clearly and concisely. +Insert only the format score as a float (e.g., 1.0, 0.5, 0.0) +Insert only the content score as a float (e.g., 1.0, 0.5, 0.0) + +> ✅ Important: +> - Output **exactly** the three tags shown above. +> - Do **not** include any additional text, explanation, or formatting outside the tags. +> - Scores must be based **only** on the doctor's **last message** and the provided Reference Information. +> - Ensure clarity and precision in your evaluation reasoning within the `` tag. +""" + + +llm = RobustDashScopeChatModel("qwen-plus", stream=False) + + +async def llm_reward(init_messages: list[dict], response: str, truth_info: str): + def format_messages(messages: list[dict]) -> str: + result_str = "" + for msg in messages: + if msg["role"] == "user": + result_str += f"patient: {msg['content']}\n" + if msg["role"] == "assistant": + result_str += f"doctor: {msg['content']}\n" + return result_str + + def parse_tag_string(text: str): + pattern = r"<(\w+)>(.*?)" + matches = re.findall(pattern, text) + result = {} + for tag, value in matches: + result[tag] = value + return result + + history = format_messages([] + init_messages + [{"role": "assistant", "content": response}]) + messages = [ + {"role": "system", "content": reward_prompt.format(truth_info)}, + {"role": "user", "content": history}, + ] + + try_count, max_retries = 0, 5 + while try_count <= max_retries: + try: + + async def get_content(): + from agentscope.model import ChatResponse + + response = await llm(messages) + + if isinstance(response, ChatResponse): + res = "".join([x["text"] for x in response.content if "text" in x]) + else: + res = "" + async for chunk in response: + res += "".join([x["text"] for x in chunk.content if "text" in x]) + return res + + content = await get_content() + score_dict = parse_tag_string(content) + return score_dict + except Exception as e: + if try_count > max_retries: + logger.warning("retried too many times, abort task.") + return None + else: + logger.warning(f"error: {e}, response:{response}, retrying...") + time.sleep(2**try_count) + + +async def reward_fn(init_messages: list[dict], response: str, truth_action: str, truth_info: str): + """ + content_score: R_a, the reward for response quality + action_score: R_s, the reward for decision correctness + format_score: P, the reward for response format + """ + + action_response = "stop" if "" in response else "continue" + if truth_action == action_response: + action_score = 1.0 + if truth_action == "continue": + score_dict = await llm_reward(init_messages, response, truth_info) + if score_dict is not None: + format_score = float(score_dict.get("format_score", 0.0)) + content_score = float(score_dict.get("content_score", 0.0)) + else: + format_score, content_score = 0.0, 0.0 + else: + content_score = 1.0 + format_score = 1.0 if response == "" else 0.0 + else: + action_score, format_score, content_score = 0.0, 0.0, 0.0 + + # treat as self.train_mode == "Ra+Rs", the default setting + final_reward = action_score * (1 + 2 * content_score) + format_score + + return final_reward + + +_reward_semaphore = threading.Semaphore(16) + + +async def reward_fn_with_semaphore(*args, **kwargs): + get_sem_ok = False + while not get_sem_ok: + get_sem_ok = _reward_semaphore.acquire(blocking=False) + if not get_sem_ok: + await asyncio.sleep(1) + + try: + fn_result = await reward_fn(*args, **kwargs) + finally: + _reward_semaphore.release() + + return fn_result + + +class ExampleLearn2Ask(Workflow): + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from langchain_openai import ChatOpenAI + from langchain.agents import create_agent + + messages = workflow_task.task.init_messages + assert isinstance(messages, list) + truth_action = workflow_task.task.metadata["decision_truth"] or "continue" + truth_info = workflow_task.task.metadata["info_truth"] + + llm_info = tuner.as_oai_baseurl_apikey() + + llm = ChatOpenAI( + base_url=llm_info.base_url, + api_key=lambda: llm_info.api_key, + ) + + agent = create_agent( + model=llm, + system_prompt=system_prompt, + ) + + msg = [{"role": x["role"], "content": x["content"]} for x in messages] + result = agent.invoke( + { + "messages": msg, # type: ignore + } + ) + + response = result["messages"][-1].content + reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info) + return WorkflowOutput(reward=reward) diff --git a/tutorial/example_ma_deepresearch/ma_deepresearch.py b/tutorial/example_ma_deepresearch/ma_deepresearch.py new file mode 100644 index 00000000..9eaba34c --- /dev/null +++ b/tutorial/example_ma_deepresearch/ma_deepresearch.py @@ -0,0 +1,74 @@ +from typing import List +from loguru import logger +from pydantic import BaseModel, Field +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat import ChatCompletionMessageToolCall +from textwrap import dedent + +import json +import os +import asyncio +import requests + + +# ------------------------------------------------------ +# Simple version - no tool call +# ------------------------------------------------------ + + +class DeepResearchInputSchema(BaseModel): + base_url: str = Field(default="", description="The base URL of the OpenAI-compatible API.") + api_key: str = Field(default="", description="The API key for authentication.") + init_messages: List[dict] = Field(default=[], description="The initial messages for the deep research task.") + task_id: str = Field(default="", description="The unique identifier for the research task.") + main_query: str = Field(default="", description="The main query for the research task.") + max_steps: int = Field(default=20, description="The maximum number of steps for the research task.") + env_service_url: str = Field(default="", description="The URL of the environment service.") + + +class ExampleMaDeepResearch(Workflow): + name: str = "multiagent_deep_research_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore + # Extract base URL and API key from the tuner + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + init_messages = workflow_task.task.init_messages + + # Get the AGENT_SERVER_URL from environment variables or use a default value + agent_server_url = os.getenv("AGENT_SERVER_URL", "http://localhost:8000") + + # Prepare the payload using DeepResearchInputSchema + payload = DeepResearchInputSchema( + base_url=base_url, + api_key=api_key, + init_messages=init_messages, + task_id=workflow_task.task.task_id, + main_query=workflow_task.task.main_query, + max_steps=tuner.config.astune.rollout.multi_turn.max_steps, + env_service_url=workflow_task.gym_env.service_url, + ) + + try: + # Send the HTTP POST request to the AGENT_SERVER_URL + headers = { + "Content-Type": "application/json", + } + + response = requests.post( + agent_server_url, + headers=headers, + data=payload.model_dump(), + ) + + # Check if the request was successful + if response.status_code == 200: + result_data = response.json() + logger.info(f"Successfully received response: {result_data}") + result = WorkflowOutput(**result_data) + return result + + except Exception as e: + logger.error(f"An error occurred while sending the request: {e}") diff --git a/tutorial/example_math_agent/math_agent.md b/tutorial/example_math_agent/math_agent.md new file mode 100644 index 00000000..894f34aa --- /dev/null +++ b/tutorial/example_math_agent/math_agent.md @@ -0,0 +1,4 @@ +# Training a basic math agent + + +Please refer to document at [`docs/en/example_app_world.md`](docs/en/example_app_world.md) diff --git a/tutorial/example_math_agent/math_agent.py b/tutorial/example_math_agent/math_agent.py new file mode 100644 index 00000000..3c5d61df --- /dev/null +++ b/tutorial/example_math_agent/math_agent.py @@ -0,0 +1,55 @@ +from agentscope.message import Msg +from loguru import logger + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + +def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if hasattr(result, "metadata") and isinstance(result.metadata, dict) and "result" in result.metadata: + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + +system_prompt = """ +You are an agent specialized in solving math problems with tools. +Please solve the math problem given to you. +You can write and execute Python code to perform calculation or verify your answer. +You should return your final answer within \\boxed{{}}. +""" + + +class ExampleMathLearn(Workflow): + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_math_agent/math_agent.yaml b/tutorial/example_math_agent/math_agent.yaml new file mode 100644 index 00000000..59f7889a --- /dev/null +++ b/tutorial/example_math_agent/math_agent.yaml @@ -0,0 +1,76 @@ +# ------------------ main configuration ------------------ +ajet: + project_name: example_math_agent + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + # effective when `type: huggingface_dat_repo` + huggingface_dat_repo: + dataset_path: '/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main' + training_split: "train" + validation_split: "test" + + task_judge: + # ✨✨✨✨ define your evaluation function + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAsJudge + + model: + # ✨✨✨✨ set the model to be trained + path: Qwen/Qwen2.5-7B + + rollout: + user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ write and select workflow + # user_workflow: "tutorial.example_math_agent.math_agent_langchain->ExampleMathLearn" # ✨if you prefer langchain version + # user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn_Simple_NoToolCall" # ✨if you prefer openai sdk version without toolcall + # user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn" # ✨if you prefer openai sdk version with toolcall + # user_workflow: "tutorial/example_math_agent/math_agent_raw_http.py->ExampleMathLearn" # ✨if you do not want to use any agentic framwork at all + # user_workflow: "tutorial/example_math_agent/math_agent_simplify.py->MathToolWorkflow" # ✨if you prefer to compute reward inside workflow + temperature: 1.0 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: 0.0 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + multi_turn: + max_sample_per_task: 2 + compute_madness_checklist: + - "nonsense" + - "wrong_toolcall" + max_response_length_in_one_turn: 1024 + max_model_len: 10000 + n_vllm_engine: 2 + + data: + train_batch_size: 100 + max_prompt_length: 3000 + max_response_length: 7000 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + trainer_common: + save_freq: 100 + test_freq: 100 + total_epochs: 100 + logger: swanlab + + +trinity: + synchronizer: + sync_offset: 1 + sync_method: nccl + + +# ------------------ do not modify ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + - file://ajet/default_config/trinity + +# ------------------ do not modify ------------------ +defaults: + - verl_default + - trinity_default + - ajet_default + - _self_ diff --git a/tutorial/example_math_agent/math_agent_langchain.py b/tutorial/example_math_agent/math_agent_langchain.py new file mode 100644 index 00000000..c11f53cc --- /dev/null +++ b/tutorial/example_math_agent/math_agent_langchain.py @@ -0,0 +1,56 @@ +from loguru import logger +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat import ChatCompletionMessageToolCall +from textwrap import dedent + +import json +import asyncio +import requests +from langchain.agents import create_agent + + +# ------------------------------------------------------ +# Simple version - no tool call +# ------------------------------------------------------ + + +class ExampleMathLearn(Workflow): + name: str = "math_agent_workflow" + system_prompt: str = dedent( + """ + You are an agent specialized in solving math problems. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + ) + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore + # tuner to api key + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + from langchain_openai import ChatOpenAI + + llm = ChatOpenAI( + base_url=base_url, + api_key=lambda: api_key, + ) + agent = create_agent( + model=llm, + system_prompt=self.system_prompt, + ) + + # take out query + query = workflow_task.task.main_query + + response = agent.invoke( + { + "messages": [{"role": "user", "content": query}], + } + ) + + final_answer = response["messages"][-1].content + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_math_agent/math_agent_oai_sdk.py b/tutorial/example_math_agent/math_agent_oai_sdk.py new file mode 100644 index 00000000..efd6e03d --- /dev/null +++ b/tutorial/example_math_agent/math_agent_oai_sdk.py @@ -0,0 +1,101 @@ +from loguru import logger +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat import ChatCompletionMessageToolCall +from textwrap import dedent + +import json +import asyncio + + +# ------------------------------------------------------ +# Simple version - no tool call +# ------------------------------------------------------ + + +class ExampleMathLearn_Simple_NoToolCall(Workflow): + name: str = "math_agent_workflow" + system_prompt: str = dedent( + """ + You are an agent specialized in solving math problems. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + ) + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore + query = workflow_task.task.main_query + client = tuner.as_raw_openai_sdk_client() + + messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": query}] + reply_message: ChatCompletion = await client.chat.completions.create(messages=messages) + final_answer = reply_message.choices[0].message.content + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) + + +# ------------------------------------------------------ +# Tool use version +# ------------------------------------------------------ + + +class ExampleMathLearn(Workflow): + name: str = "math_agent_workflow" + system_prompt: str = dedent( + """ + You are an agent specialized in solving math problems with tools. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + ) + available_functions: list = [ + {"type": "function", "function": {"name": "execute_python_code", "description": "Execute the given Python code in a temp file and capture the return code, standard output, and error. Note that you should print something or you will get empty return.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The Python code to be executed."}, "timeout": {"type": "number", "description": "The maximum time (in seconds) allowed for the code to run.", "default": 300}}, "required": ["code"]}}}, + ] + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore + query = workflow_task.task.main_query + client = tuner.as_raw_openai_sdk_client() + + # call 1: get response with tool call + messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": query}] + reply_message: ChatCompletion = await client.chat.completions.create(messages=messages, tools=self.available_functions) + if reply_message.choices[0].message.content: + messages.append({"role": "assistant", "content": reply_message.choices[0].message.content}) + + # If the model called a tool + if (reply_message.choices[0].message) and (reply_message.choices[0].message.tool_calls): + tool_calls: list[ChatCompletionMessageToolCall] = reply_message.choices[0].message.tool_calls + for tool_call in tool_calls: + if tool_call.function.name == "execute_python_code": + arguments = json.loads(tool_call.function.arguments) + + def sync_wrapper(): + import subprocess + import sys + + process = subprocess.run([sys.executable, "-c", arguments["code"]], timeout=arguments.get("timeout", 300), capture_output=True, text=True) + return process.stdout + + result = await asyncio.to_thread(sync_wrapper) + tool_result_message = { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_call.function.name, + "content": json.dumps( + { + "return_code": str(result), + } + ), + } + messages.append(tool_result_message) + + # Step 3: Make a follow-up API call with the tool result + final_response: ChatCompletion = await client.chat.completions.create( + messages=messages, + ) + final_stage_response = final_response.choices[0].message.content + else: + final_stage_response = reply_message.choices[0].message.content + + return WorkflowOutput(reward=None, metadata={"final_answer": final_stage_response}) diff --git a/tutorial/example_math_agent/math_agent_raw_http.py b/tutorial/example_math_agent/math_agent_raw_http.py new file mode 100644 index 00000000..af7ed484 --- /dev/null +++ b/tutorial/example_math_agent/math_agent_raw_http.py @@ -0,0 +1,49 @@ +from loguru import logger +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat import ChatCompletionMessageToolCall +from textwrap import dedent + +import json +import asyncio +import requests + + +# ------------------------------------------------------ +# Simple version - no tool call +# ------------------------------------------------------ + + +class ExampleMathLearn(Workflow): + name: str = "math_agent_workflow" + system_prompt: str = dedent( + """ + You are an agent specialized in solving math problems. + Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. + You should return your final answer within \\boxed{{}}. + """ + ) + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore + # tuner to api key + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key + + # take out query + query = workflow_task.task.main_query + + messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": query}] + + # use raw http requests (non-streaming) to get response + response = requests.post( + f"{base_url}/chat/completions", + json={ + "model": "whatever", # Of course, this `model` field will be ignored. + "messages": messages, + }, + headers={"Authorization": f"Bearer {api_key}"}, + ) + final_answer = response.json()["choices"][0]["message"]["content"] + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_math_agent/math_agent_simplify.py b/tutorial/example_math_agent/math_agent_simplify.py new file mode 100644 index 00000000..4b1b8b79 --- /dev/null +++ b/tutorial/example_math_agent/math_agent_simplify.py @@ -0,0 +1,68 @@ +import re +from loguru import logger +from agentscope.message import Msg +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + +def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if hasattr(result, "metadata") and isinstance(result.metadata, dict) and "result" in result.metadata: + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + +system_prompt = """ +You are an agent specialized in solving math problems with tools. +Please solve the math problem given to you. +You can write and execute Python code to perform calculation or verify your answer. +You should return your final answer within \\boxed{{}}. +""" + + +class MathToolWorkflow(Workflow): + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"] + reference_answer = reference_answer.split("####")[-1].strip() + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: + result = match.group(1) + is_success = result == reference_answer + else: + is_success = False + raw_reward = 1.0 if is_success else 0.0 + return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_math_agent/math_agentscope_urlkey.py b/tutorial/example_math_agent/math_agentscope_urlkey.py new file mode 100644 index 00000000..874fd91e --- /dev/null +++ b/tutorial/example_math_agent/math_agentscope_urlkey.py @@ -0,0 +1,78 @@ +import re +from loguru import logger +from agentscope.message import Msg +from agentscope.agent import ReActAgent +from agentscope.formatter import OpenAIChatFormatter +from agentscope.model import OpenAIChatModel +from agentscope.memory import InMemoryMemory +from agentscope.tool import Toolkit, execute_python_code +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + +def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if hasattr(result, "metadata") and isinstance(result.metadata, dict) and "result" in result.metadata: + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + +system_prompt = """ +You are an agent specialized in solving math problems with tools. +Please solve the math problem given to you. +You can write and execute Python code to perform calculation or verify your answer. +You should return your final answer within \\boxed{{}}. +""" + + +class MathToolWorkflow(Workflow): # ✨✨ inherit `Workflow` class + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # run agentscope + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + + url_and_apikey = tuner.as_oai_baseurl_apikey() + base_url = url_and_apikey.base_url + api_key = url_and_apikey.api_key # the api key contain information, do not discard it + + # print(f"[MathToolWorkflow] Using base_url: [{base_url}], api_key: [{api_key}]") + # base_url: [http://10.56.3.98:57817/v1], api_key: [sk-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa] + + model = OpenAIChatModel( + model_name="whatever", + client_args={"base_url": base_url}, + api_key=api_key, + stream=False, + ) + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=model, # ✨✨ compared with a normal agentscope agent, here is the difference! + formatter=OpenAIChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + max_iters=2, + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg) + final_answer = extract_final_answer(result) + + # compute reward + reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip() + match = re.search(r"\\boxed\{([^}]*)\}", final_answer) + if match: + is_success = match.group(1) == reference_answer + else: + is_success = False + return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer}) diff --git a/tutorial/example_math_agent/math_answer_as_judge.py b/tutorial/example_math_agent/math_answer_as_judge.py new file mode 100644 index 00000000..d2b7a81f --- /dev/null +++ b/tutorial/example_math_agent/math_answer_as_judge.py @@ -0,0 +1,61 @@ +import re + +from ajet.task_judge.base_judge import BaseJudge +from ajet.task_rollout.dashscope_llm_bridge import create_external_llm_fn +from ajet.workflow import WorkflowOutput, WorkflowTask + + +class MathAnswerAsJudge(BaseJudge): + def __init__(self, config): + self.config = config + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple: + raw_reward = 0 + final_answer = workflow_output.metadata["final_answer"] # By default there's no final_answer; register it by calling ajet_proxy.update_judge_input_dictionary(final_answer=final_answer) in the workflow + reference_answer = workflow_task.task.metadata["answer"] + reference_answer = reference_answer.split("####")[-1].strip() + + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: + result = match.group(1) + is_success = result == reference_answer + else: + is_success = False + + raw_reward = 1.0 if is_success else 0.0 + return raw_reward, is_success + + +class MathAnswerAndLlmAsJudge(BaseJudge): + def __init__(self, config): + self.config = config + + def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple: + raw_reward = 0 + final_answer = workflow_output.metadata["final_answer"] # By default there's no final_answer; register it by calling ajet_proxy.update_judge_input_dictionary(final_answer=final_answer) in the workflow + reference_answer = workflow_task.task.metadata["answer"] + reference_answer = reference_answer.split("####")[-1].strip() + + external_llm_fn = create_external_llm_fn( + alien_llm_model=self.config.ajet.task_judge.alien_llm_model, + alien_llm_response_length=self.config.ajet.task_judge.alien_llm_response_length, + ) + messages = [ + { + "role": "system", + "content": "Is my result correct? If correct, say , otherwise say .", + }, + { + "role": "user", + "content": f"Is my result correct?\n\n\n----\nMy result: {final_answer}\n\n\n----\nReal result: {reference_answer}", + }, + ] + res = external_llm_fn(messages=messages) + if "" in res["content"]: + is_success = True + raw_reward = 1.0 + else: + is_success = False + raw_reward = 0.0 + return raw_reward, is_success diff --git a/tutorial/example_rm_auto_grader/auto_grader.md b/tutorial/example_rm_auto_grader/auto_grader.md new file mode 100644 index 00000000..897f9510 --- /dev/null +++ b/tutorial/example_rm_auto_grader/auto_grader.md @@ -0,0 +1,263 @@ +# Auto Grader Judge + +A data-driven judge that automatically generates evaluation rubrics from reference samples using RM Gallery's Iterative Rubrics Generator. + +## What is Auto Grader Judge? + +Auto Grader Judge is an intelligent evaluation system that **learns how to grade your AI agent's outputs** by analyzing examples of good and bad responses. Instead of manually writing evaluation rules, it automatically discovers scoring criteria from your training data through an iterative Propose-Evaluate-Revise process, then generates structured rubrics (Theme-Tips format) that can be inspected and understood. + +**Key Features:** +- Automatic rubric generation from reference samples +- Support for both pointwise (scoring) and listwise (ranking) evaluation +- MCR²-based smart sampling for large datasets +- Optional LLM-based categorization +- Seamless integration with ajet's workflow system + +### When to Use Auto Grader Judge? + +**✅ Ideal For:** +- **Open-ended tasks**: Dialogue generation, creative writing, explanations +- **Subjective quality assessment**: Where "correctness" has nuance (helpfulness, clarity, style) +- **Complex multi-aspect evaluation**: Need to assess accuracy, completeness, fluency, etc. +- **Large-scale RL training**: Need automated, consistent evaluation with reward signals + +**⚠️ Not Recommended For:** +- **Tasks with exact answers**: Use `EnvServiceJudge` or exact match instead +- **Fully objective tasks**: API calls, code execution, mathematical computation + +## Quick Start + +### 1. Configuration + +Add to your `ajet_default.yaml`: + +```yaml +ajet: + task_judge: + judge_type: rubrics_auto_grader + + rubrics_auto_grader: + # Model settings + model_name: qwen-max + + # Grader configuration + grader_mode: pointwise # or "listwise" + language: en # or "zh" + + # auto grader configuration + query_specific_generate_number: 1 + enable_categorization: false + categories_number: 5 + + # Custom evaluation prompt + custom_evaluation_prompt: null + + # Field mappings + query_field: main_query + answer_field: final_answer + reference_field: answer + + # Training data + input_data_type: dataset_file + dataset_file: + training: + file_path: "path/to/training_data.jsonl" + + # Pointwise mode only + min_score: 0 + max_score: 10 +``` + +### 2. Training Data Format + +#### Pointwise Mode +Each sample contains a query, answer, and score: + +```json +{ + "main_query": "What is 2 + 2?", + "metadata": { + "answer": "2 + 2 = 4", + "score": 1 + } +} +``` + +#### Listwise Mode +Each sample contains a query with multiple ranked candidates: + +```json +{ + "main_query": "What is 2 + 2?", + "metadata": { + "candidates": [ + {"answer": "2 + 2 = 4", "rank": 1}, + {"answer": "2 + 2 = 5", "rank": 2}, + {"answer": "I don't know", "rank": 3} + ] + } +} +``` + +### 3. Basic Usage + +```python +from ajet.task_judge.rm_auto_grader_judge import AutoGraderJudge + +# Initialize judge +judge = AutoGraderJudge(config) + +# Generate rubrics (one-time setup) +await judge.generate_rubrics_from_samples() + +# Or load from cache +await judge.load_rubrics_from_cache() + +# Evaluate outputs +result = await judge._async_compute_reward(task, workflow_output) + +# For pointwise: result is a GraderScore object +print(f"Score: {result.score}, Reason: {result.reason}") + +# For listwise: result is a GraderRank object +print(f"Ranks: {result.rank}, Reason: {result.reason}") +``` + +## Configuration Parameters + +### Required Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `model_name` | str | DashScope model name (e.g., qwen-max, qwen-plus) | +| `grader_mode` | str | Evaluation mode: "pointwise" or "listwise" | +| `language` | str | Language: "en" or "zh" | +| `input_data_type` | str | Data source type: "dataset_file", "env_service", etc. | + +### Field Mapping + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `query_field` | "main_query" | Field name containing the query | +| `answer_field` | "final_answer" | Field name containing the answer | +| `reference_field` | "answer" | Field name containing the reference | + +### Pointwise Mode Settings + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `min_score` | 0 | Minimum score value | +| `max_score` | 10 | Maximum score value | + +### Advanced Options + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `grader_name` | "RM Iterative Rubric Grader" | Name of the grader | +| `enable_categorization` | false | Enable LLM-based rubric categorization | +| `categories_number` | 5 | Number of categories (if enabled) | +| `query_specific_generate_number` | 1 | Number of rubrics per sample | +| `custom_evaluation_prompt` | null | Path to custom evaluation prompt file | + + +## Evaluation Modes + +### Pointwise Mode +Evaluates individual outputs independently. + +**Returns:** `GraderScore` object +- `score`: Numerical score (float) +- `reason`: Explanation for the score +- `metadata`: Additional information + +**Use case:** Absolute quality assessment + +### Listwise Mode +Ranks multiple outputs together. + +**Returns:** `GraderRank` object +- `rank`: List of rankings (e.g., [1, 3, 2]) +- `reason`: Explanation for the ranking +- `metadata`: Additional information + +**Use case:** Relative comparison, preference ranking + +## Cache Management + +Generated rubrics are automatically saved to: +``` +{experiment_dir}/auto_grader.json +``` + +To reuse rubrics: +```python +await judge.load_rubrics_from_cache() +``` + +This skips the generation phase and loads pre-generated rubrics. + +## Example: Pointwise Evaluation + +```python +# Create reference samples +reference_samples = [ + Task( + task_id="1", + main_query="What is 5 + 3?", + metadata={"answer": "5 + 3 = 8", "score": 1} + ), + # ... more samples +] + +# Initialize and generate rubrics +judge = AutoGraderJudge(config) +await judge.generate_rubrics_from_samples(reference_samples) + +# Create test task and output +test_task = Task(task_id="test_1", main_query="What is 7 + 2?") +test_output = WorkflowOutput(metadata={"final_answer": "7 + 2 = 9"}) + +# Evaluate +result = await judge._async_compute_reward(test_task, test_output) + +print(f"Score: {result.score}") +print(f"Reasoning: {result.reason}") +``` + +## Example: Listwise Evaluation + +```python +# Create reference samples with rankings +reference_samples = [ + Task( + task_id="1", + main_query="Explain photosynthesis", + metadata={ + "candidates": [ + {"answer": "Detailed scientific explanation...", "rank": 1}, + {"answer": "Brief explanation...", "rank": 2}, + {"answer": "Incorrect explanation...", "rank": 3} + ] + } + ), + # ... more samples +] + +# Initialize and generate rubrics +judge = AutoGraderJudge(config) +await judge.generate_rubrics_from_samples(reference_samples) + +# Create test task with multiple candidates +test_task = Task(task_id="test_1", main_query="What is the water cycle?") +candidates = [ + WorkflowOutput(metadata={"final_answer": "Water evaporates, forms clouds, and rains"}), + WorkflowOutput(metadata={"final_answer": "It's when water moves around"}), + WorkflowOutput(metadata={"final_answer": "Detailed explanation of evaporation, condensation, precipitation..."}) +] + +# Evaluate +result = await judge._async_compute_reward(test_task, candidates) + +print(f"Rankings: {result.rank}") # e.g., [2, 3, 1] +print(f"Reasoning: {result.reason}") +``` diff --git a/tutorial/example_rm_auto_grader/auto_grader_example.py b/tutorial/example_rm_auto_grader/auto_grader_example.py new file mode 100644 index 00000000..cfd08afc --- /dev/null +++ b/tutorial/example_rm_auto_grader/auto_grader_example.py @@ -0,0 +1,454 @@ +""" +Example: Using RM Iterative Rubric Judge with ajet + +This example demonstrates how to use the RM Gallery IterativeRubricsGenerator integration +for data-driven evaluation of workflow outputs. + +The IterativeRubricsGenerator uses an iterative Propose-Evaluate-Revise loop to +generate high-quality evaluation rubrics from reference samples. + +This example shows: +1. Pointwise evaluation mode (scoring individual outputs) +2. Listwise evaluation mode (ranking multiple outputs) +""" + +import asyncio +from typing import List + +from rm_gallery.core.generator.iterative_rubric.query_rubric_generator import ( + LISTWISE_EVALUATION_TEMPLATE, + POINTWISE_EVALUATION_TEMPLATE, +) + +from ajet.schema.task import Task +from ajet.task_judge.rm_auto_grader_judge import AutoGraderJudge +from ajet.workflow import WorkflowOutput, WorkflowTask + +# ============================================ +# Example 1: Pre-generated Rubrics Approach +# ============================================ + + +async def example_pregerated_rubrics(): + """ + Example of using AutoGraderJudge with iteratively-generated rubrics (Pointwise mode). + + This approach uses the IterativeRubricsGenerator to automatically create + evaluation rubrics from reference samples using a Propose-Evaluate-Revise loop. + """ + print("\n\nExample 1: Pointwise Evaluation with Iterative Rubrics") + + # Mock config object + class MockConfig: + class Ajet: + class TaskJudge: + class RubricsAutoGrader: + # Model configuration + model_name = "qwen3-32b" + + # Grader configuration + grader_mode = "pointwise" + language = "en" + min_score = 0 + max_score = 1 + + # Evaluation prompt template + custom_evaluation_prompt = POINTWISE_EVALUATION_TEMPLATE + + # Advanced configuration (optional) + query_specific_generate_number = 1 + max_epochs = 2 + max_retries = 3 + enable_categorization = False + + # Field mappings + query_field = "main_query" + answer_field = "final_answer" + reference_field = "answer" + + grader_name = "Math Iterative Rubric Grader" + + rubrics_auto_grader = RubricsAutoGrader() + + task_judge = TaskJudge() + experiment_dir = "/tmp/rm_grader_example" + + ajet = Ajet() + + config = MockConfig() + + # Step 1: Create reference samples for rubric generation + reference_samples = create_math_reference_samples(num_samples=10) + + # Step 2: Initialize judge + judge = AutoGraderJudge(config) + + # Step 3: Generate rubrics from reference samples using iterative refinement + await judge.generate_rubrics_from_samples(reference_samples) + + # Step 4: Evaluate new samples using generated rubrics + test_samples = create_math_test_samples(num_samples=5) + + for i, (workflow_task, output) in enumerate(test_samples, 1): + print(f"\n--- Test Sample {i} ---") + print(f"Query: {workflow_task.task.main_query}") + print(f"Answer: {output.metadata['final_answer']}") + print(f"Reference: {workflow_task.task.metadata['answer']}") + + # Use async method directly since we're in async context + reward = await judge._async_compute_reward(workflow_task.task, output) + print(f"Result: {reward}") + + print("Example 1 completed!") + + +# ============================================ +# Example 2: Listwise Mode with Multiple Outputs +# ============================================ + + +async def example_listwise_mode(): + """ + Example of using AutoGraderJudge in Listwise mode with iterative rubrics. + + Listwise mode ranks multiple candidate answers for the same query. + This is useful for: + - Comparing multiple model outputs + - Ranking candidate responses by quality + - Batch evaluation of similar tasks + """ + print("\n\nExample 2: Listwise Ranking with Iterative Rubrics") + + # Mock config object + class MockConfig: + class Ajet: + class TaskJudge: + class RubricsAutoGrader: + # Model configuration + model_name = "qwen3-32b" + + # Grader configuration - LISTWISE mode + grader_mode = "listwise" # Key difference! + language = "en" + # Note: min_score/max_score not needed for listwise mode + + # Evaluation prompt template + custom_evaluation_prompt = LISTWISE_EVALUATION_TEMPLATE + + # Advanced configuration (optional) + query_specific_generate_number = 2 + max_epochs = 2 + max_retries = 3 + enable_categorization = False + + # Field mappings + query_field = "main_query" + answer_field = "final_answer" + reference_field = "answer" + + grader_name = "Math Listwise Iterative Grader" + + rubrics_auto_grader = RubricsAutoGrader() + + task_judge = TaskJudge() + experiment_dir = "/tmp/rm_grader_example_listwise" + + ajet = Ajet() + + config = MockConfig() + + # Step 1: Create reference samples with multiple outputs per query + reference_samples = create_listwise_reference_samples(num_samples=5) + + # Step 2: Initialize judge + judge = AutoGraderJudge(config) + + # Step 3: Generate ranking rubrics using iterative refinement + await judge.generate_rubrics_from_samples(reference_samples) + + # Step 4: Evaluate multiple candidate answers for new queries + test_queries = create_listwise_test_samples(num_queries=3) + + for i, (workflow_task, candidate_outputs) in enumerate(test_queries, 1): + print(f"\n{'='*50}") + print(f"Query {i}: {workflow_task.task.main_query}") + print(f"{'='*50}") + print(f"Evaluating {len(candidate_outputs)} candidates...") + + # Evaluate all candidates together (pass list for listwise mode) + grader_rank_result = await judge._async_compute_reward(workflow_task.task, candidate_outputs) + + if grader_rank_result and hasattr(grader_rank_result, "rank"): + ranks = grader_rank_result.rank + reason = grader_rank_result.reason + + print(f"\nGrader reasoning: {reason}") + + results = [] + for j, (output, rank) in enumerate(zip(candidate_outputs, ranks), 1): + results.append((j, output.metadata["final_answer"], rank)) + + # Sort by rank (ascending, rank 1 is best) + results.sort(key=lambda x: x[2]) + + print("\nRanking (best to worst):") + for display_rank, (idx, answer, model_rank) in enumerate(results, 1): + print(f" {display_rank}. Candidate {idx}: '{answer}' (Model Rank: {model_rank})") + else: + print("No results returned from evaluation") + + print("\n" + "=" * 60) + print("Example 2 completed!") + print("=" * 60) + + +# ============================================ +# Helper Functions +# ============================================ + + +def create_math_reference_samples(num_samples: int = 10) -> List[Task]: + """ + Create reference math problem samples for Pointwise rubric generation. + + Each sample contains a single answer with a score label. + """ + samples = [] + + # Simple math problems with answers and scores + # Format: (query, answer, score) + problems = [ + ("What is 15 + 27?", "42", 1), + ("Calculate 8 * 9", "72", 1), + ("What is 100 - 37?", "63", 1), + ("Find the value of 144 / 12", "12", 1), + ("What is 5^3?", "125", 1), + ("Calculate 23 + 45 - 18", "50", 1), + ("What is 7 * 8 + 6?", "62", 1), + ("Find the value of (15 + 5) * 2", "40", 1), + ("What is 99 - 33 - 22?", "44", 1), + ("Calculate 16 / 4 + 10", "14", 1), + ] + + for i in range(min(num_samples, len(problems))): + query, answer, score = problems[i] + + task = Task( + main_query=query, + task_id=f"ref_sample_{i}", + metadata={"answer": answer, "score": score}, # Pointwise label + ) + + samples.append(task) + + return samples + + +def create_listwise_reference_samples(num_samples: int = 5) -> List[Task]: + """ + Create reference samples for Listwise mode rubric generation. + + Each sample should contain multiple outputs with different quality levels. + This helps the model learn to distinguish between good and bad answers. + """ + samples = [] + + # Math problems with multiple candidate answers and their quality rankings + # Lower rank = better quality (rank 1 is best) + problems = [ + ( + "What is 10 + 15?", + [ + ("25", 1), # Perfect answer + ("Twenty-five", 2), # Correct but different format + ("24", 3), # Close but wrong + ("30", 4), # Wrong + ], + ), + ( + "Calculate 6 * 7", + [ + ("42", 1), # Perfect + ("6*7=42", 2), # Correct with work shown + ("43", 3), # Off by one + ("36", 4), # Wrong (6*6) + ], + ), + ( + "What is 50 - 18?", + [ + ("32", 1), # Perfect + ("50-18=32", 2), # Correct with work + ("33", 3), # Close + ("42", 4), # Wrong + ], + ), + ( + "Find 12 / 4", + [ + ("3", 1), # Perfect + ("3.0", 2), # Correct, decimal format + ("4", 3), # Wrong + ("2", 4), # Very wrong + ], + ), + ( + "What is 2^5?", + [ + ("32", 1), # Perfect + ("2*2*2*2*2=32", 2), # Correct with work + ("16", 3), # 2^4, common mistake + ("10", 4), # 2*5, wrong operation + ], + ), + ] + + for i in range(min(num_samples, len(problems))): + query, candidates = problems[i] + + # Create task with metadata containing all candidates and their ranks + task = Task( + main_query=query, + task_id=f"listwise_ref_{i}", + metadata={"candidates": [{"answer": ans, "rank": rank} for ans, rank in candidates]}, + ) + + samples.append(task) + + return samples + + +def create_listwise_test_samples( + num_queries: int = 3, +) -> List[tuple[WorkflowTask, List[WorkflowOutput]]]: + """ + Create test queries with multiple candidate outputs for Listwise evaluation. + + Returns: + List of (query_task, list_of_candidate_outputs) tuples + """ + test_data = [] + + # Test queries with multiple candidate answers + queries = [ + ( + "What is 45 + 37?", + [ + "82", # Correct + "45+37=82", # Correct with work + "83", # Off by one + "72", # Wrong + "Eighty-two", # Correct, word form + ], + ), + ( + "Calculate 9 * 8", + [ + "72", # Correct + "73", # Close + "81", # 9*9, common mistake + "9*8=72", # Correct with work + "64", # 8*8, wrong + ], + ), + ( + "What is 100 - 45?", + [ + "55", # Correct + "100-45=55", # Correct with work + "65", # Wrong + "Fifty-five", # Correct, word form + "54", # Off by one + ], + ), + ] + + for i in range(min(num_queries, len(queries))): + query, candidates = queries[i] + + # Create task + task = Task( + main_query=query, + task_id=f"listwise_test_{i}", + metadata={}, # No reference needed for evaluation + ) + + workflow_task = WorkflowTask( + task_id=f"listwise_test_{i}", + task=task, + ) + + # Create output for each candidate + candidate_outputs = [] + for j, candidate_answer in enumerate(candidates): + output = WorkflowOutput(metadata={"final_answer": candidate_answer}) + candidate_outputs.append(output) + + test_data.append((workflow_task, candidate_outputs)) + + return test_data + + +def create_math_test_samples(num_samples: int = 5) -> List[tuple[WorkflowTask, WorkflowOutput]]: + """Create test samples (task + output pairs) for evaluation.""" + samples = [] + + # Test problems with model outputs (some correct, some incorrect) + test_cases = [ + ("What is 25 + 38?", "63", "63", True), + ("Calculate 12 * 5", "60", "60", True), + ("What is 90 - 45?", "45", "45", True), + ("Find the value of 64 / 8", "8", "7", False), # Wrong answer + ("What is 3^4?", "81", "64", False), # Wrong answer + ("Calculate 18 + 22", "40", "40", True), + ("What is 9 * 7?", "63", "56", False), # Wrong answer + ("Find the value of (10 + 5) * 3", "45", "45", True), + ("What is 77 - 33?", "44", "44", True), + ("Calculate 20 / 5 + 15", "19", "19", True), + ("What is 6 * 8 - 10?", "38", "40", False), # Wrong answer + ("Find the value of 55 + 45", "100", "100", True), + ("What is 100 - 25 - 25?", "50", "50", True), + ("Calculate 7^2", "49", "49", True), + ("What is (20 - 5) / 3?", "5", "7", False), # Wrong answer + ] + + for i in range(min(num_samples, len(test_cases))): + query, reference, model_output, _ = test_cases[i] + + task = Task(main_query=query, task_id=f"test_sample_{i}", metadata={"answer": reference}) + + workflow_task = WorkflowTask( + task_id=f"test_sample_{i}", + task=task, + ) + + workflow_output = WorkflowOutput(metadata={"final_answer": model_output}) + + samples.append((workflow_task, workflow_output)) + + return samples + + +# ============================================ +# Main Entry Point +# ============================================ + + +async def main(): + """Run all examples.""" + + # Run examples + try: + await example_pregerated_rubrics() + await example_listwise_mode() + + except Exception as e: + print(f"\n✗ Error running examples: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + # Run the examples + asyncio.run(main()) diff --git a/tutorial/example_rm_auto_grader/config.md b/tutorial/example_rm_auto_grader/config.md new file mode 100644 index 00000000..9a93b8c4 --- /dev/null +++ b/tutorial/example_rm_auto_grader/config.md @@ -0,0 +1,135 @@ +# Example configuration for RM Auto Grader Judge +# This configuration integrates RM Gallery's AutoGrader capabilities into ajet + +ajet + # Task judge configuration using RM Auto Grader + task_judge: + # Use AutoGraderJudge for pre-generated rubrics + # Or use RMAutoGraderBatchJudge for online rubric generation + + # ======================================== + # Model Configuration + # ======================================== + # LLM model for rubric generation and evaluation + model_name: "qwen3-32b" # Options: qwen-plus, qwen-max, gpt-4, gpt-3.5-turbo, etc. + + # ======================================== + # Grader Mode Configuration + # ======================================== + grader_mode: "pointwise" # Options: "pointwise" or "listwise" + language: "en" # Options: "en" or "zh" + + # Score range for pointwise evaluation + min_score: 0 + max_score: 10 + + # Success threshold (0.0 - 1.0) - what normalized score counts as success + success_threshold: 0.7 + + # ======================================== + # Rubric Generation Configuration + # ======================================== + # Sampling mode for rubric generation + sampling_mode: "all_samples" # Options: "all_samples" or "smart_sampling" + + # Number of rubrics to generate per sample + generate_number: 3 + + # Maximum epochs for iterative refinement + max_epochs: 3 + + # Maximum retry attempts for LLM API calls + max_retries: 5 + + # Batch processing settings (for smart_sampling mode) + batch_size: 10 + mcr_batch_size: 10 + + # Aggregation mode for final rubrics + aggregation_mode: "keep_all" # Options: "keep_all" or "merge_similar" + + # ======================================== + # Reference Samples Configuration + # ======================================== + # Path to reference samples (for pre-generating rubrics) + # reference_samples_path: "data/reference_samples.jsonl" + + # Number of reference samples to use + num_reference_samples: 20 + + # ======================================== + # Field Mapping Configuration + # ======================================== + # Field names for extracting data from WorkflowTask and WorkflowOutput + query_field: "main_query" # Field in task containing the query + answer_field: "final_answer" # Field in output metadata containing the answer + reference_field: "answer" # Field in task.metadata containing reference answer + + # ======================================== + # Grader name for logging + # ======================================== + grader_name: "RM Auto Grader" + + +# ============================================ +# Batch Judge Specific Configuration +# ============================================ +# Uncomment and use these settings when using RMAutoGraderBatchJudge + +# ajet +# task_judge: +# class_name: RMAutoGraderBatchJudge +# +# # ... (include all settings from above) +# +# # Warmup phase settings +# warmup_samples: 20 # Collect N samples before generating rubrics +# +# # Regeneration settings +# regenerate_interval: 0 # Regenerate rubrics every N evaluations (0 = never) + + +# ============================================ +# Example for Math Problem Evaluation +# ============================================ +# ajet +# task_judge: +# class_name: AutoGraderJudge +# model_name: "qwen-plus" +# grader_mode: "pointwise" +# language: "en" +# min_score: 0 +# max_score: 10 +# success_threshold: 0.8 +# sampling_mode: "all_samples" +# generate_number: 5 +# max_epochs: 3 +# aggregation_mode: "merge_similar" +# num_reference_samples: 30 +# query_field: "main_query" +# answer_field: "final_answer" +# reference_field: "answer" + + +# ============================================ +# Example for Agent Task Evaluation +# ============================================ +# ajet +# task_judge: +# class_name: AutoGraderJudge +# model_name: "gpt-4" +# grader_mode: "pointwise" +# language: "en" +# min_score: 0 +# max_score: 100 +# success_threshold: 0.7 +# sampling_mode: "smart_sampling" +# generate_number: 3 +# max_epochs: 2 +# batch_size: 15 +# mcr_batch_size: 10 +# aggregation_mode: "keep_all" +# num_reference_samples: 50 +# query_field: "main_query" +# answer_field: "agent_output" +# reference_field: "expected_outcome" diff --git a/tutorial/example_rm_auto_grader/rubrics_train.jsonl b/tutorial/example_rm_auto_grader/rubrics_train.jsonl new file mode 100644 index 00000000..8778c5f8 --- /dev/null +++ b/tutorial/example_rm_auto_grader/rubrics_train.jsonl @@ -0,0 +1,10 @@ +{"main_query": "What is 15 + 27?", "answer": "42", "score": 1} +{"main_query": "Calculate 8 * 9", "answer": "72", "score": 1} +{"main_query": "What is 100 - 37?", "answer": "63", "score": 1} +{"main_query": "Find the value of 144 / 12", "answer": "12", "score": 1} +{"main_query": "What is 5^3?", "answer": "125", "score": 1} +{"main_query": "Calculate 23 + 45 - 18", "answer": "50", "score": 1} +{"main_query": "What is 7 * 8 + 6?", "answer": "62", "score": 1} +{"main_query": "Find the value of (15 + 5) * 2", "answer": "40", "score": 1} +{"main_query": "What is 99 - 33 - 22?", "answer": "44", "score": 1} +{"main_query": "Calculate 16 / 4 + 10", "answer": "14", "score": 1} diff --git a/tutorial/math_agent.md b/tutorial/example_rubrics_judge/math_agent.md similarity index 82% rename from tutorial/math_agent.md rename to tutorial/example_rubrics_judge/math_agent.md index 5649622f..0d1c74b4 100644 --- a/tutorial/math_agent.md +++ b/tutorial/example_rubrics_judge/math_agent.md @@ -45,7 +45,7 @@ self.toolkit.register_tool_function(execute_python_code) self.agent = ReActAgent( name="math_react_agent", sys_prompt=system_prompt, - model=beyondagent_proxy, # type: ignore + model=ajet_proxy, # type: ignore formatter=DashScopeChatFormatter(), toolkit=self.toolkit, memory=InMemoryMemory(), @@ -57,21 +57,21 @@ result = await self.agent.reply(msg, structured_model=FinalResult) - 在 AgentScope Workflow 中,注册评价函数需要的任意关键数据 ```python -beyondagent_proxy.update_judge_input_dictionary(final_answer=final_answer) +ajet_proxy.update_judge_input_dictionary(final_answer=final_answer) ``` ### 3. 准备Judge (奖励模块) -在 astune/task_judge/math_answer_as_judge.py 中,提供了两个简单的Judge。可以在项目任意地方新建新的Judge代码 +在 tutorial/example_math_agent/math_answer_as_judge.py 中,提供了两个简单的Judge。可以在项目任意地方新建新的Judge代码 Judge的输入参数包含: ```python judge_input_dictionary['env']: env_service 外部环境 (如果使用了env_service) -judge_input_dictionary['task_core_arg']: 任务信息(如果里面包含了参考答案,可以从中取出) +judge_input_dictionary['workflow_task']: 任务信息(如果里面包含了参考答案,可以从中取出) judge_input_dictionary['grouped_steps']: LLM的每一次历史对话记录(如果中间过程比较重要,可以从中取出) -judge_input_dictionary['final_answer']: 默认没有final_answer,需要在agentscope workflow中手动调用 beyondagent_proxy.update_judge_input_dictionary(final_answer=final_answer) 注册 +judge_input_dictionary['final_answer']: 默认没有final_answer,需要在agentscope workflow中手动调用 ajet_proxy.update_judge_input_dictionary(final_answer=final_answer) 注册 ``` Judge的返回值: raw_reward, is_success @@ -80,38 +80,37 @@ Judge的返回值: raw_reward, is_success ### 4. 测试 -4.1 复制并修改 [launcher/math_agent/git-math-agentscope.yaml](../launcher/math_agent/git-math-agentscope.yaml) 中的关键参数,yaml中与本文档最相关的部分已经用✨✨✨✨符号标记 +4.1 复制并修改 [tutorial/example_rubrics_judge/math_agent.yaml](../tutorial/example_rubrics_judge/math_agent.yaml) 中的关键参数,yaml中与本文档最相关的部分已经用✨✨✨✨符号标记 -1. 读取task(对应配置字段 astune.task_reader) -2. 定义 AgentScopeWorkflow(对应配置字段 astune.rollout.agentscope_learn_protocol ) +1. 读取task(对应配置字段 ajet.task_reader) +2. 定义 Workflow(对应配置字段 ajet.rollout.user_workflow ) - 举例如果 agentscope workflow 定义在 `tutorial/math_agent.py` 的`ExampleMathLear` 类 - - 则填写 astune.rollout.agentscope_learn_protocol=`tutorial.math_agent->ExampleMathLearn` -3. 定义评分函数(对应配置字段 astune.task_judge.judge_protocol ) - - 举例如果 agentscope workflow 定义在 `astune/task_judge/math_answer_as_judge.py` 的`MathAnswerAndLlmAsJudge` 类 - - 则填写 astune.task_judge.judge_protocol=`astune.task_judge.math_answer_as_judge->MathAnswerAndLlmAsJudge` -4. 指定模型(对应配置字段 astune.model.path ) + - 则填写 ajet.rollout.user_workflow=`tutorial.math_agent->ExampleMathLearn` +3. 定义评分函数(对应配置字段 ajet.task_judge.judge_protocol ) + - 举例如果 agentscope workflow 定义在 `tutorial/example_math_agent/math_answer_as_judge.py` 的`MathAnswerAndLlmAsJudge` 类 + - 则填写 ajet.task_judge.judge_protocol=`tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge` +4. 指定模型(对应配置字段 ajet.model.path ) ```yaml -astune: +ajet task_reader: type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` rollout: - use_agentscope_protocol: True - agentscope_learn_protocol: tutorial.math_agent->ExampleMathLearn # ✨✨✨✨ 编写并选择Agent + user_workflow: tutorial.math_agent->ExampleMathLearn # ✨✨✨✨ 编写并选择Agent task_judge: # ✨✨✨✨ 编写并选择评价函数 - judge_protocol: astune.task_judge.math_answer_as_judge->MathAnswerAndLlmAsJudge + judge_protocol: tutorial.example_math_agent.math_answer_as_judge->MathAnswerAndLlmAsJudge model: # ✨✨✨✨ 设置待训练的模型 - path: /mnt/data/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + path: YOUR_MODEL_PATH ``` 4.2 全链路调试(脱离ray快速调试:--backbone='debug') ```bash -# (训练math agent demo)建议开始前杀死所有ray、env_service进程 ( python launcher.py --kill="python|ray" ) +# (训练math agent demo)建议开始前杀死所有ray、env_service进程 ( ajet --kill="python|ray" ) clear && \ -python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbone='debug' --with-logview +ajet --conf tutorial/example_rubrics_judge/math_agent.yaml --backbone='debug' --with-logview ``` 备注:当--backbone=debug时,程序不再使用ray,可以编写vscode的launch.json进行便捷的断点调试,launch.json的配置: ```json @@ -123,7 +122,7 @@ python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbon "name": "Python Debugger: Launch rollout", "type": "debugpy", "request": "launch", - "program": "launcher.py", + "program": "ajet/cli/launcher.py", "console": "integratedTerminal", "args": [ "--backbone", "debug", @@ -137,10 +136,10 @@ python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbon ``` -4.3 当调试完成后,开始训练(只需要把backbone切换一下即可:--backbone='verl') +4.3 当调试完成后,开始训练(只需要把backbone切换一下即可:--backbone='trinity') ```bash -# 建议开始前杀死所有ray、vllm、env_service进程 ( python launcher.py --kill="python|ray|vllm" ) -python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbone='verl' +# 建议开始前杀死所有ray、vllm、env_service进程 ( ajet --kill="python|ray|vllm" ) +ajet --conf tutorial/example_rubrics_judge/math_agent.yaml --backbone='trinity' --with-ray ``` @@ -150,7 +149,7 @@ python launcher.py --conf launcher/math_agent/git-math-agentscope.yaml --backbon 日志界面 -- 找到日志文件夹,默认在 `./launcher_record/exp_yaml_file_name/*` 下面 +- 找到日志文件夹,默认在 `./saved_experiments/exp_yaml_file_name/*` 下面 - 运行 `beast_logger_go` 启动日志浏览器,vscode端口映射8181端口 ```bash root@xxxx:/xxx/xxx/xxx# beast_logger_go @@ -160,12 +159,18 @@ INFO: Application startup complete. INFO: Uvicorn running on http://127.0.0.1:8181 (Press CTRL+C to quit) ``` - 打开 http://127.0.0.1:8181,提示输入日志文件路径,填写日志文件夹的**绝对路径**,以下形式皆可 - - /mnt/data/qingxu.fu/astune/astune/launcher_record - - /mnt/data/qingxu.fu/astune/astune/launcher_record/exp_yaml_file_name - - /mnt/data/qingxu.fu/astune/astune/launcher_record/exp_yaml_file_name/2025_11_10_02_52/rollout + - /ajet/ajet/saved_experiments + - /ajet/ajet/saved_experiments/exp_yaml_file_name + - /ajet/ajet/saved_experiments/exp_yaml_file_name/2025_11_10_02_52/rollout - 依次打开界面 **左侧** 的日志文件目标,**中间** 的日志条目,**右侧** 的交互记录,即可显示完整的轨迹 - 蓝色 Token 代表参与loss计算的 Token,黄色反之 -- 鼠标悬浮在 Token 上面可以查看 Token 的 **logprob** (暂时仅限trinity backbone) \ No newline at end of file +- 鼠标悬浮在 Token 上面可以查看 Token 的 **logprob** (暂时仅限trinity backbone) + + +### 6. 参考训练曲线 + +点击链接打开训练曲线: +https://swanlab.cn/@binaryhusky/public/runs/96arcunrxlezdmcvmcdob/chart diff --git a/tutorial/example_rubrics_judge/math_agent.py b/tutorial/example_rubrics_judge/math_agent.py new file mode 100644 index 00000000..f035ac57 --- /dev/null +++ b/tutorial/example_rubrics_judge/math_agent.py @@ -0,0 +1,59 @@ +from agentscope.message import Msg +from loguru import logger +from pydantic import BaseModel, Field + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask + + +def extract_final_answer(result) -> str: + """Extract the final answer from the agent's response.""" + try: + if hasattr(result, "metadata") and isinstance(result.metadata, dict) and "result" in result.metadata: + return result.metadata["result"] + if hasattr(result, "content"): + if isinstance(result.content, dict) and "result" in result.content: + return result.content["result"] + return str(result.content) + return str(result) + except Exception as e: + logger.warning(f"Extract final answer error: {e}. Raw: {result}") + return str(result) + + +class FinalResult(BaseModel): + result: str = Field(description="Your solution of the given math problem. Put your final answer in boxed format, e.g., \\boxed{42}") + + +system_prompt = """ +You are an agent specialized in solving math problems with tools. +Please solve the math problem given to you. +You can write and execute Python code to perform calculation or verify your answer. +You should return your final answer within \\boxed{{}}. +""" + + +class ExampleMathLearn(Workflow): + name: str = "math_agent_workflow" + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + from agentscope.agent import ReActAgent + from agentscope.formatter import DashScopeChatFormatter + from agentscope.memory import InMemoryMemory + from agentscope.tool import Toolkit, execute_python_code + + query = workflow_task.task.main_query + self.toolkit = Toolkit() + self.toolkit.register_tool_function(execute_python_code) + self.agent = ReActAgent( + name="math_react_agent", + sys_prompt=system_prompt, + model=tuner.as_agentscope_model(), + formatter=DashScopeChatFormatter(), + toolkit=self.toolkit, + memory=InMemoryMemory(), + ) + self.agent.set_console_output_enabled(False) + msg = Msg("user", query, role="user") + result = await self.agent.reply(msg, structured_model=FinalResult) + final_answer = extract_final_answer(result) + return WorkflowOutput(reward=None, metadata={"final_answer": final_answer}) diff --git a/tutorial/example_rubrics_judge/r_judge.yaml b/tutorial/example_rubrics_judge/r_judge.yaml new file mode 100644 index 00000000..f7a171ec --- /dev/null +++ b/tutorial/example_rubrics_judge/r_judge.yaml @@ -0,0 +1,72 @@ +# ------------------ 主要配置 ------------------ +ajet: + project_name: example_rubrics_judge + task_reader: + type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` + # 如果选择 `huggingface_dat_repo` 以下配置生效 + huggingface_dat_repo: + dataset_path: '/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main' + training_split: "train" + validation_split: "test" + + task_judge: + # ✨✨✨✨ 编写并选择评价函数 + judge_type: rubrics_auto_grader # Options: 'customized_protocol', 'rubrics_auto_grader' + rubrics_auto_grader: + # rubrics begin + model_name: qwen-max + grader_mode: pointwise + language: en + min_score: 0 + max_score: 1 + success_threshold: 0.7 + sampling_mode: all_samples + generate_number: 1 + max_epochs: 2 + max_retries: 3 + aggregation_mode: keep_all + grader_name: Math Auto Grader + num_reference_samples: 20 + query_field: main_query + answer_field: final_answer + reference_field: answer + input_data_type: dataset_file # `env_service` or `dataset_file` or `huggingface_dat_repo` + dataset_file: + training: + file_path: "tutorial/example_rm_auto_grader/rubrics_train.jsonl" + + + model: + # ✨✨✨✨ 设置待训练的模型 + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + + rollout: + user_workflow: "tutorial.example_rubrics_judge.math_agent->ExampleMathLearn" # ✨✨✨✨ 编写并选择Agent + temperature: 0.7 + max_env_worker: 80 + num_repeat: 4 + agent_madness_reward: 0.0 + + data: + train_batch_size: 64 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + + + +# ------------------ 不需要修改 ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ 不需要修改 ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/example_werewolves/game.py b/tutorial/example_werewolves/game.py new file mode 100644 index 00000000..7f72f9f9 --- /dev/null +++ b/tutorial/example_werewolves/game.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# pylint: disable=too-many-branches, too-many-statements, no-name-in-module +"""A werewolf game implemented by agentscope.""" +from agentscope.agent import ReActAgent +from agentscope.pipeline import MsgHub, fanout_pipeline, sequential_pipeline + +# Uncomment the following line to use Chinese prompts +# from tutorial.example_werewolves.prompt import ChinesePrompts as Prompts +from loguru import logger + +from tutorial.example_werewolves.prompt import EnglishPrompts as Prompts +from tutorial.example_werewolves.structured_model import ( + DiscussionModel, + WitchResurrectModel, + get_hunter_model, + get_poison_model, + get_seer_model, + get_vote_model, +) +from tutorial.example_werewolves.utils import ( + MAX_DISCUSSION_ROUND, + MAX_GAME_ROUND, + EchoAgent, + Players, + majority_vote, + names_to_str, +) + + +class BadGuyException(Exception): + ... + + +moderator = EchoAgent() +# moderator.set_console_output_enabled(False) + + +async def hunter_stage( + hunter_agent: ReActAgent, + players: Players, +) -> str | None: + """Because the hunter's stage may happen in two places: killed at night + or voted during the day, we define a function here to avoid duplication.""" + global moderator + msg_hunter = await hunter_agent( + await moderator(Prompts.to_hunter.format(name=hunter_agent.name)), + structured_model=get_hunter_model(players.current_alive), + ) + if msg_hunter.metadata.get("shoot"): + return msg_hunter.metadata.get("name", None) + return None + + +async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C901 + """The main entry of the werewolf game + + Args: + agents (`list[ReActAgent]`): + A list of 9 agents. + """ + assert len(agents) == 9, "The werewolf game needs exactly 9 players." + + # Init the players' status + players = Players() + + # If the witch has healing and poison potion + healing, poison = True, True + + # If it's the first day, the dead can leave a message + first_day = True + + # Broadcast the game begin message + async with MsgHub(participants=agents) as greeting_hub: + await greeting_hub.broadcast( + await moderator( + Prompts.to_all_new_game.format(names_to_str(agents)), + ), + ) + + # Assign roles to the agents + for agent, role in zip(agents, roles): + # Tell the agent its role + await agent.observe( + await moderator( + f"[{agent.name} ONLY] {agent.name}, your role is {role}.", + ), + ) + players.add_player(agent, role) + + # Printing the roles + players.print_roles() + + # GAME BEGIN! + for _ in range(MAX_GAME_ROUND): + # Create a MsgHub for all players to broadcast messages + async with MsgHub( + participants=players.current_alive, + enable_auto_broadcast=False, # manual broadcast only + name="alive_players", + ) as alive_players_hub: + # Night phase + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_night), + ) + killed_player, poisoned_player, shot_player = None, None, None + + try: + # Werewolves discuss + async with MsgHub( + players.werewolves, + enable_auto_broadcast=True, + announcement=await moderator( + Prompts.to_wolves_discussion.format( + names_to_str(players.werewolves), + names_to_str(players.current_alive), + ), + ), + name="werewolves", + ) as werewolves_hub: + # Discussion + n_werewolves = len(players.werewolves) + for _ in range(1, MAX_DISCUSSION_ROUND * n_werewolves + 1): + res = await players.werewolves[_ % n_werewolves]( + structured_model=DiscussionModel, + ) + if _ % n_werewolves == 0 and res.metadata.get( + "reach_agreement", + ): + break + + # Werewolves vote + # Disable auto broadcast to avoid following other's votes + werewolves_hub.set_auto_broadcast(False) + msgs_vote = await fanout_pipeline( + players.werewolves, + msg=await moderator(content=Prompts.to_wolves_vote), + structured_model=get_vote_model(players.current_alive), + enable_gather=False, + ) + killed_player, votes = majority_vote( + [_.metadata.get("vote") for _ in msgs_vote], + ) + # Postpone the broadcast of voting + await werewolves_hub.broadcast( + [ + *msgs_vote, + await moderator( + Prompts.to_wolves_res.format(votes, killed_player), + ), + ], + ) + except Exception as e: + raise BadGuyException( + f"Werewolves failed to make a decision: {e}", + ) + + # Witch's turn + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_witch_turn), + ) + msg_witch_poison = None + for agent in players.witch: + # Cannot heal witch herself + msg_witch_resurrect = None + if healing and killed_player != agent.name: + msg_witch_resurrect = await agent( + await moderator( + Prompts.to_witch_resurrect.format( + witch_name=agent.name, + dead_name=killed_player, + ), + ), + structured_model=WitchResurrectModel, + ) + if msg_witch_resurrect.metadata.get("resurrect"): + killed_player = None + healing = False + + # Has poison potion and hasn't used the healing potion + if poison and not (msg_witch_resurrect and msg_witch_resurrect.metadata["resurrect"]): + msg_witch_poison = await agent( + await moderator( + Prompts.to_witch_poison.format( + witch_name=agent.name, + ), + ), + structured_model=get_poison_model( + players.current_alive, + ), + ) + if msg_witch_poison.metadata.get("poison"): + poisoned_player = msg_witch_poison.metadata.get("name") + poison = False + + # Seer's turn + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_seer_turn), + ) + for agent in players.seer: + msg_seer = await agent( + await moderator( + Prompts.to_seer.format( + agent.name, + names_to_str(players.current_alive), + ), + ), + structured_model=get_seer_model(players.current_alive), + ) + if msg_seer.metadata.get("name"): + player = msg_seer.metadata["name"] + await agent.observe( + await moderator( + Prompts.to_seer_result.format( + agent_name=player, + role=players.name_to_role[player], + ), + ), + ) + + # Hunter's turn + for agent in players.hunter: + # If killed and not by witch's poison + if killed_player == agent.name and poisoned_player != agent.name: + shot_player = await hunter_stage(agent, players) + + # Update alive players + dead_tonight = [killed_player, poisoned_player, shot_player] + players.update_players(dead_tonight) + + # Day phase + if len([_ for _ in dead_tonight if _]) > 0: + await alive_players_hub.broadcast( + await moderator( + Prompts.to_all_day.format( + names_to_str([_ for _ in dead_tonight if _]), + ), + ), + ) + + # The killed player leave a last message in first night + if killed_player and first_day: + msg_moderator = await moderator( + Prompts.to_dead_player.format(killed_player), + ) + await alive_players_hub.broadcast(msg_moderator) + # Leave a message + last_msg = await players.name_to_agent[killed_player]() + await alive_players_hub.broadcast(last_msg) + + else: + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_peace), + ) + + # Check winning + res = players.check_winning() + if res: + await moderator(res) + break + + # Discussion + await alive_players_hub.broadcast( + await moderator( + Prompts.to_all_discuss.format( + names=names_to_str(players.current_alive), + ), + ), + ) + # Open the auto broadcast to enable discussion + alive_players_hub.set_auto_broadcast(True) + await sequential_pipeline(players.current_alive) + # Disable auto broadcast to avoid leaking info + alive_players_hub.set_auto_broadcast(False) + + # Voting + msgs_vote = await fanout_pipeline( + players.current_alive, + await moderator( + Prompts.to_all_vote.format( + names_to_str(players.current_alive), + ), + ), + structured_model=get_vote_model(players.current_alive), + enable_gather=False, + ) + voted_player, votes = majority_vote( + [_.metadata.get("vote") for _ in msgs_vote], + ) + # Broadcast the voting messages together to avoid influencing + # each other + voting_msgs = [ + *msgs_vote, + await moderator( + Prompts.to_all_res.format(votes, voted_player), + ), + ] + + # Leave a message if voted + if voted_player: + prompt_msg = await moderator( + Prompts.to_dead_player.format(voted_player), + ) + last_msg = await players.name_to_agent[voted_player]( + prompt_msg, + ) + voting_msgs.extend([prompt_msg, last_msg]) + + await alive_players_hub.broadcast(voting_msgs) + + # If the voted player is the hunter, he can shoot someone + shot_player = None + for agent in players.hunter: + if voted_player == agent.name: + shot_player = await hunter_stage(agent, players) + if shot_player: + await alive_players_hub.broadcast( + await moderator( + Prompts.to_all_hunter_shoot.format( + shot_player, + ), + ), + ) + + # Update alive players + dead_today = [voted_player, shot_player] + players.update_players(dead_today) + + # Check winning + res = players.check_winning() + if res: + async with MsgHub(players.all_players) as all_players_hub: + res_msg = await moderator(res) + await all_players_hub.broadcast(res_msg) + break + + # The day ends + first_day = False + + # # Game over, each player reflects + # await fanout_pipeline( + # agents=agents, + # msg=await moderator(Prompts.to_all_reflect), + # ) + + alive_wolves = players.werewolves + good_guy_win = len(alive_wolves) == 0 + logger.warning("**********************************") + logger.warning(f"Good guy win: {good_guy_win}, alive werewolves: {alive_wolves}") + return good_guy_win diff --git a/tutorial/example_werewolves/prompt.py b/tutorial/example_werewolves/prompt.py new file mode 100644 index 00000000..d04306f0 --- /dev/null +++ b/tutorial/example_werewolves/prompt.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +"""Default prompts""" + + +class EnglishPrompts: + """English prompts used to guide the werewolf game.""" + + to_dead_player = "{}, you're eliminated now. Now you can make a final statement to " "all alive players before you leave the game." + + to_all_new_game = "A new game is starting, the players are: {}. Now we randomly " "reassign the roles to each player and inform them of their roles " "privately." + + to_all_night = "Night has fallen, everyone close your eyes. Werewolves open your " "eyes and choose a player to eliminate tonight." + + to_wolves_discussion = "[WEREWOLVES ONLY] {}, you should discuss and " "decide on a player to eliminate tonight. Current alive players " "are {}. Remember to set `reach_agreement` to True if you reach an " "agreement during the discussion." + + to_wolves_vote = "[WEREWOLVES ONLY] Which player do you vote to kill?" + + to_wolves_res = "[WEREWOLVES ONLY] The voting result is {}. So you have chosen to " "eliminate {}." + + to_all_witch_turn = "Witch's turn, witch open your eyes and decide your action tonight..." + to_witch_resurrect = "[WITCH ONLY] {witch_name}, you're the witch, and tonight {dead_name} " "is eliminated. You can resurrect him/her by using your healing " "potion, " "and note you can only use it once in the whole game. Do you want to " "resurrect {dead_name}? Give me your reason and decision." + + to_witch_resurrect_no = "[WITCH ONLY] The witch has chosen not to resurrect the player." + to_witch_resurrect_yes = "[WITCH ONLY] The witch has chosen to resurrect the player." + + to_witch_poison = "[WITCH ONLY] {witch_name}, as a witch, you have a one-time-use " "poison potion, do you want to use it tonight? Give me your reason " "and decision." + + to_all_seer_turn = "Seer's turn, seer open your eyes and check one player's identity " "tonight..." + + to_seer = "[SEER ONLY] {}, as the seer you can check one player's identity " "tonight. Who do you want to check? Give me your reason and decision." + + to_seer_result = "[SEER ONLY] You've checked {agent_name}, and the result is: {role}." + + to_hunter = "[HUNTER ONLY] {name}, as the hunter you're eliminated tonight. You " "can choose one player to take down with you. Also, you can choose " "not to use this ability. Give me your reason and decision." + + to_all_hunter_shoot = "The hunter has chosen to shoot {} down with him/herself." + + to_all_day = "The day is coming, all players open your eyes. Last night, " "the following player(s) has been eliminated: {}." + + to_all_peace = "The day is coming, all the players open your eyes. Last night is " "peaceful, no player is eliminated." + + to_all_discuss = "Now the alive players are {names}. The game goes on, it's time to " "discuss and vote a player to be eliminated. Now you each take turns " "to speak once in the order of {names}." + + to_all_vote = "Now the discussion is over. Everyone, please vote to eliminate one " "player from the alive players: {}." + + to_all_res = "The voting result is {}. So {} has been voted out." + + to_all_wolf_win = "There are {n_alive} players alive, and {n_werewolves} of them are " "werewolves. " "The game is over and werewolves win🐺🎉!" "In this game, the true roles of all players are: {true_roles}" + + to_all_village_win = "All the werewolves have been eliminated." "The game is over and villagers win🏘️🎉!" "In this game, the true roles of all players are: {true_roles}" + + to_all_continue = "The game goes on." + + to_all_reflect = "The game is over. Now each player can reflect on their performance. " "Note each player only has one chance to speak and the reflection is " "only visible to themselves." + + +class ChinesePrompts: + """Chinese prompts used to guide the werewolf game.""" + + to_dead_player = "{}, 你已被淘汰。现在你可以向所有存活玩家发表最后的遗言。" + + to_all_new_game = "新的一局游戏开始,参与玩家包括:{}。现在为每位玩家重新随机分配身份,并私下告知各自身份。" + + to_all_night = "天黑了,请所有人闭眼。狼人请睁眼,选择今晚要淘汰的一名玩家..." + + to_wolves_discussion = "[仅狼人可见] {}, 你们可以讨论并决定今晚要淘汰的玩家。当前存活玩家有:{}。" "如果达成一致,请将 `reach_agreement` 设为 True。" + + to_wolves_vote = "[仅狼人可见] 你投票要杀死哪位玩家?" + + to_wolves_res = "[仅狼人可见] 投票结果为 {},你们选择淘汰 {}。" + + to_all_witch_turn = "轮到女巫行动,女巫请睁眼并决定今晚的操作..." + to_witch_resurrect = "[仅女巫可见] {witch_name},你是女巫,今晚{dead_name}被淘汰。" "你可以用解药救他/她,注意解药全局只能用一次。你要救{dead_name}吗?" "请给出理由和决定。" + + to_witch_resurrect_no = "[仅女巫可见] 女巫选择不救该玩家。" + to_witch_resurrect_yes = "[仅女巫可见] 女巫选择救活该玩家。" + + to_witch_poison = "[仅女巫可见] {witch_name},你有一瓶一次性毒药,今晚要使用吗?请给出理由和决定。" + + to_all_seer_turn = "轮到预言家行动,预言家请睁眼并查验一名玩家身份..." + + to_seer = "[仅预言家可见] {}, 你是预言家,今晚可以查验一名玩家身份。你要查谁?请给出理由和决定。" + + to_seer_result = "[仅预言家可见] 你查验了{agent_name},结果是:{role}。" + + to_hunter = "[仅猎人可见] {name},你是猎人,今晚被淘汰。你可以选择带走一名玩家,也可以选择不带走。请给出理由和决定。" + + to_all_hunter_shoot = "猎人选择带走 {} 一起出局。" + + to_all_day = "天亮了,请所有玩家睁眼。昨晚被淘汰的玩家有:{}。" + + to_all_peace = "天亮了,请所有玩家睁眼。昨晚平安夜,无人被淘汰。" + + to_all_discuss = "现在存活玩家有:{names}。游戏继续,大家开始讨论并投票淘汰一名玩家。请按顺序({names})依次发言。" + + to_all_vote = "讨论结束。请大家从存活玩家中投票淘汰一人:{}。" + + to_all_res = "投票结果为 {},{} 被淘汰。" + + to_all_wolf_win = "当前存活玩家共{n_alive}人,其中{n_werewolves}人为狼人。" "游戏结束,狼人获胜🐺🎉!" "本局所有玩家真实身份为:{true_roles}" + + to_all_village_win = "所有狼人已被淘汰。游戏结束,村民获胜🏘️🎉!本局所有玩家真实身份为:{true_roles}" + + to_all_continue = "游戏继续。" + + to_all_reflect = "游戏结束。现在每位玩家可以对自己的表现进行反思。注意每位玩家只有一次发言机会,且反思内容仅自己可见。" diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py new file mode 100644 index 00000000..f9d27a53 --- /dev/null +++ b/tutorial/example_werewolves/start.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# flake8: noqa: E501 + +"""The main entry point for the werewolf game.""" + +from typing import List +import numpy as np +import dotenv + +dotenv.load_dotenv() + +from textwrap import dedent + +from agentscope.agent import ReActAgent +from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter +from agentscope.model import DashScopeChatModel, OpenAIChatModel +from loguru import logger +from pydantic import Field + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from tutorial.example_werewolves.game import BadGuyException, werewolves_game + + +def get_official_agent_prompt(name) -> str: + system_prompt = dedent( + f""" + You're a werewolf game player named {name}. + + # YOUR TARGET + Your target is to win the game with your teammates as much as possible. + + # GAME RULES + - In werewolf game, players are divided into three werewolves, three villagers, one seer, one hunter and one witch. + - Werewolves: kill one player each night, and must hide identity during the day. + - Villagers: ordinary players without special abilities, try to identify and eliminate werewolves. + - Seer: A special villager who can check one player's identity each night. + - Witch: A special villager with two one-time-use potions: a healing potion to save a player from being killed at night, and a poison to eliminate one player at night. + - Hunter: A special villager who can take one player down with them when they are eliminated. + - The game alternates between night and day phases until one side wins: + - Night Phase + - Werewolves choose one victim + - Seer checks one player's identity + - Witch decides whether to use potions + - Moderator announces who died during the night + - Day Phase + - All players discuss and vote to eliminate one suspected player + + # GAME GUIDANCE + - Try your best to win the game with your teammates, tricks, lies, and deception are all allowed, e.g. pretending to be a different role. + - During discussion, don't be political, be direct and to the point. + - The day phase voting provides important clues. For example, the werewolves may vote together, attack the seer, etc. + ## GAME GUIDANCE FOR WEREWOLF + - Seer is your greatest threat, who can check one player's identity each night. Analyze players' speeches, find out the seer and eliminate him/her will greatly increase your chances of winning. + - In the first night, making random choices is common for werewolves since no information is available. + - Pretending to be other roles (seer, witch or villager) is a common strategy to hide your identity and mislead other villagers in the day phase. + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR SEER + - Seer is very important to villagers, exposing yourself too early may lead to being targeted by werewolves. + - Your ability to check one player's identity is crucial. + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR WITCH + - Witch has two powerful potions, use them wisely to protect key villagers or eliminate suspected werewolves. + - The outcome of the night phase provides important clues. For example, if the dead player is hunter, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR HUNTER + - Using your ability in day phase will expose your role (since only hunter can take one player down) + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR VILLAGER + - Protecting special villagers, especially the seer, is crucial for your team's success. + - Werewolves may pretend to be the seer. Be cautious and don't trust anyone easily. + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. + + # NOTE + - [IMPORTANT] DO NOT make up any information that is not provided by the moderator or other players. + - This is a TEXT-based game, so DO NOT use or make up any non-textual information. + - Always critically reflect on whether your evidence exist, and avoid making assumptions. + - Your response should be specific and concise, provide clear reason and avoid unnecessary elaboration. + - Generate your one-line response by using the `generate_response` function. + - Don't repeat the others' speeches.""" + ) + return system_prompt + + +class ExampleWerewolves(Workflow): + trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + # ensure trainable targets is legal + assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)." + + # bad guys and good guys cannot be trained simultaneously + # (because mix-cooperation-competition MARL needs too many advanced techniques to be displayed here) + if "werewolf" in self.trainable_targets: + assert len(self.trainable_targets) == 1, "Cannot train hostile roles simultaneously." + else: + assert len(self.trainable_targets) != 0, "No trainable targets specified." + + # make and shuffle roles (fix random seed for reproducibility) + roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] + task_id = workflow_task.task.metadata["random_number"] + np.random.seed(int(task_id)) + np.random.shuffle(roles) + + # initialize agents + players = [] + for i, role in enumerate(roles): + default_model = OpenAIChatModel( + model_name="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", + stream=False, + client_args={"base_url": "http://22.17.52.4:2888/v1"}, + api_key="no_api_key", + generate_kwargs={"temperature": 0.01}, + ) + model_for_this_agent = tuner.as_agentscope_model( + agent_name=f"Player{i + 1}", # the name of this agent + target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. + debug_model=default_model, # the model used when this agent is not in `self.trainable_targets` + ) + agent = ReActAgent( + name=f"Player{i + 1}", + sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), + model=model_for_this_agent, + formatter=DashScopeMultiAgentFormatter() if role in self.trainable_targets else OpenAIMultiAgentFormatter(), + max_iters=3 if role in self.trainable_targets else 5, + ) + # agent.set_console_output_enabled(False) + players += [agent] + + # reward condition + try: + good_guy_win = await werewolves_game(players, roles) + raw_reward = 0 + is_success = False + if (good_guy_win and self.trainable_targets[0] != "werewolf") or (not good_guy_win and self.trainable_targets[0] == "werewolf"): + raw_reward = 1 + is_success = True + logger.warning(f"Raw reward: {raw_reward}") + logger.warning(f"Is success: {is_success}") + except BadGuyException as e: + logger.bind(exception=True).exception(f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent).") + raw_reward = -0.1 + is_success = False + except Exception as e: + logger.bind(exception=True).exception(f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent).") + raw_reward = -0.1 + is_success = False + + return WorkflowOutput(reward=raw_reward, is_success=is_success) diff --git a/tutorial/example_werewolves/structured_model.py b/tutorial/example_werewolves/structured_model.py new file mode 100644 index 00000000..8e2ddb20 --- /dev/null +++ b/tutorial/example_werewolves/structured_model.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +"""The structured output models used in the werewolf game.""" +from typing import Literal + +from agentscope.agent import AgentBase +from pydantic import BaseModel, Field + + +class DiscussionModel(BaseModel): + """The output format for discussion.""" + + reach_agreement: bool = Field( + description="Whether you have reached an agreement or not", + ) + + +def get_vote_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the vote model by player names.""" + + class VoteModel(BaseModel): + """The vote output format.""" + + vote: Literal[tuple(_.name for _ in agents)] = Field( # type: ignore + description="The name of the player you want to vote for", + ) + + return VoteModel + + +class WitchResurrectModel(BaseModel): + """The output format for witch resurrect action.""" + + resurrect: bool = Field( + description="Whether you want to resurrect the player", + ) + + +def get_poison_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the poison model by player names.""" + + class WitchPoisonModel(BaseModel): + """The output format for witch poison action.""" + + poison: bool = Field( + description="Do you want to use the poison potion", + ) + name: Literal[tuple(_.name for _ in agents)] | None = ( # type: ignore + Field( + description="The name of the player you want to poison, if you " "don't want to poison anyone, just leave it empty", + default=None, + ) + ) + + return WitchPoisonModel + + +def get_seer_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the seer model by player names.""" + + class SeerModel(BaseModel): + """The output format for seer action.""" + + name: Literal[tuple(_.name for _ in agents)] = Field( # type: ignore + description="The name of the player you want to check", + ) + + return SeerModel + + +def get_hunter_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the hunter model by player agents.""" + + class HunterModel(BaseModel): + """The output format for hunter action.""" + + shoot: bool = Field( + description="Whether you want to use the shooting ability or not", + ) + name: Literal[tuple(_.name for _ in agents)] | None = ( # type: ignore + Field( + description="The name of the player you want to shoot, if you " "don't want to the ability, just leave it empty", + default=None, + ) + ) + + return HunterModel diff --git a/tutorial/example_werewolves/utils.py b/tutorial/example_werewolves/utils.py new file mode 100644 index 00000000..7261424e --- /dev/null +++ b/tutorial/example_werewolves/utils.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +"""Utility functions for the werewolf game.""" +from collections import defaultdict +from typing import Any + +import numpy as np +from agentscope.agent import AgentBase, ReActAgent +from agentscope.message import Msg + +from tutorial.example_werewolves.prompt import EnglishPrompts as Prompts + +# MAX_GAME_ROUND = 30 +# MAX_DISCUSSION_ROUND = 3 +MAX_GAME_ROUND = 7 +MAX_DISCUSSION_ROUND = 2 + + +def majority_vote(votes: list[str]) -> tuple: + """Return the vote with the most counts.""" + result = max(set(votes), key=votes.count) + names, counts = np.unique(votes, return_counts=True) + conditions = ", ".join( + [f"{name}: {count}" for name, count in zip(names, counts)], + ) + return result, conditions + + +def names_to_str(agents: list[str] | list[ReActAgent]) -> str: + """Return a string of agent names.""" + if not agents: + return "" + + if len(agents) == 1: + if isinstance(agents[0], ReActAgent): + return agents[0].name + return agents[0] + + names = [] + for agent in agents: + if isinstance(agent, ReActAgent): + names.append(agent.name) + else: + names.append(agent) + return ", ".join([*names[:-1], "and " + names[-1]]) + + +class EchoAgent(AgentBase): + """Echo agent that repeats the input message.""" + + def __init__(self) -> None: + super().__init__() + self.name = "Moderator" + + async def reply(self, content: str) -> Msg: + """Repeat the input content with its name and role.""" + msg = Msg( + self.name, + content, + role="assistant", + ) + await self.print(msg) + return msg + + async def handle_interrupt( + self, + *args: Any, + **kwargs: Any, + ) -> Msg: + """Handle interrupt.""" + + async def observe(self, msg: Msg | list[Msg] | None) -> None: + """Observe the user's message.""" + + +class Players: + """Maintain the players' status.""" + + def __init__(self) -> None: + """Initialize the players.""" + # The mapping from player name to role + self.name_to_role = {} + self.role_to_names = defaultdict(list) + self.name_to_agent = {} + self.werewolves = [] + self.villagers = [] + self.seer = [] + self.hunter = [] + self.witch = [] + self.current_alive = [] + self.all_players = [] + + def add_player(self, player: ReActAgent, role: str) -> None: + """Add a player to the game. + + Args: + player (`ReActAgent`): + The player to be added. + role (`str`): + The role of the player. + """ + self.name_to_role[player.name] = role + self.name_to_agent[player.name] = player + self.role_to_names[role].append(player.name) + self.all_players.append(player) + if role == "werewolf": + self.werewolves.append(player) + elif role == "villager": + self.villagers.append(player) + elif role == "seer": + self.seer.append(player) + elif role == "hunter": + self.hunter.append(player) + elif role == "witch": + self.witch.append(player) + else: + raise ValueError(f"Unknown role: {role}") + self.current_alive.append(player) + + def update_players(self, dead_players: list[ReActAgent]) -> None: + """Update the current alive players. + + Args: + dead_players (`list[ReActAgent]`): + A list of dead players to be removed. + """ + self.werewolves = [_ for _ in self.werewolves if _.name not in dead_players] + self.villagers = [_ for _ in self.villagers if _.name not in dead_players] + self.seer = [_ for _ in self.seer if _.name not in dead_players] + self.hunter = [_ for _ in self.hunter if _.name not in dead_players] + self.witch = [_ for _ in self.witch if _.name not in dead_players] + self.current_alive = [_ for _ in self.current_alive if _.name not in dead_players] + + def print_roles(self) -> None: + """Print the roles of all players.""" + print("Roles:") + for name, role in self.name_to_role.items(): + print(f" - {name}: {role}") + + def check_winning(self) -> str | None: + """Check if the game is over and return the winning message.""" + + # Prepare true roles string + true_roles = f'{names_to_str(self.role_to_names["werewolf"])} are werewolves, ' f'{names_to_str(self.role_to_names["villager"])} are villagers, ' f'{names_to_str(self.role_to_names["seer"])} is the seer, ' f'{names_to_str(self.role_to_names["hunter"])} is the hunter, ' f'and {names_to_str(self.role_to_names["witch"])} is the witch.' + + if len(self.werewolves) * 2 >= len(self.current_alive): + return Prompts.to_all_wolf_win.format( + n_alive=len(self.current_alive), + n_werewolves=len(self.werewolves), + true_roles=true_roles, + ) + if self.current_alive and not self.werewolves: + return Prompts.to_all_village_win.format( + true_roles=true_roles, + ) + return None diff --git a/tutorial/example_werewolves/werewolves.md b/tutorial/example_werewolves/werewolves.md new file mode 100644 index 00000000..e6528d2f --- /dev/null +++ b/tutorial/example_werewolves/werewolves.md @@ -0,0 +1,4 @@ +# Training a basic math agent + + +Please refer to document at [`docs/en/example_werewolves.md`](docs/en/example_werewolves.md) diff --git a/tutorial/example_werewolves/werewolves.yaml b/tutorial/example_werewolves/werewolves.yaml new file mode 100644 index 00000000..ac6d9e72 --- /dev/null +++ b/tutorial/example_werewolves/werewolves.yaml @@ -0,0 +1,76 @@ +# ------------------ main config ------------------ +ajet: + project_name: example_werewolves + task_reader: + type: random_dummy # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo` or `random_dummy` + + task_judge: + # ✨✨✨✨ select evaluation function + judge_protocol: null + + model: + # ✨✨✨✨ select model to be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + + rollout: + # the path to the workflow class + user_workflow: tutorial.example_werewolves.start->ExampleWerewolves # ✨✨✨✨ select agent + + temperature: 0.7 + + max_env_worker: 64 + + num_repeat: 6 + + agent_madness_reward: 0.0 + + tensor_model_parallel_size: 1 + + # the number of vllm engines, number of gpus for infer is `n_vllm_engine*tensor_model_parallel_size`, this argument is NOT effective when NOT using trinity + n_vllm_engine: 8 + + max_num_seqs: 40 + + # monitor LLM's abormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + + max_response_length_in_one_turn: 1024 + + max_model_len: 22000 + + multi_turn: + # expected steps for each task run, used to calculate the training batch size for trinity, do not have to be precise + expected_steps: 3 + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 32 + max_prompt_length: 4000 + max_response_length: 18000 + + trainer_common: + save_freq: 5 + test_freq: 99999 + total_epochs: 99999 + total_training_steps: 25 + nnodes: 2 + n_gpus_per_node: 8 + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl # verl only + - file://ajet/default_config/trinity # trinity only + +# ------------------ do not edit ------------------ +defaults: + - verl_default # verl inherit 1/1 + - trinity_default # trinity inherit 1/1 + - ajet_default + - _self_ diff --git a/tutorial/figure/appworld.png b/tutorial/figure/appworld.png new file mode 100644 index 00000000..c0ce5e86 Binary files /dev/null and b/tutorial/figure/appworld.png differ diff --git a/tutorial/figure/werewolves_train_witch.png b/tutorial/figure/werewolves_train_witch.png new file mode 100644 index 00000000..4232e95f Binary files /dev/null and b/tutorial/figure/werewolves_train_witch.png differ diff --git a/tutorial/math_agent.py b/tutorial/math_agent.py deleted file mode 100644 index 084601b1..00000000 --- a/tutorial/math_agent.py +++ /dev/null @@ -1,69 +0,0 @@ -from astune.agentscope_flow import BeyondAgentProxy -from agentscope.message import Msg -from pydantic import BaseModel, Field -from astune.protocol.agentscope_protocol import AgentScopeLearnProtocol -from loguru import logger - -def extract_final_answer(result) -> str: - """Extract the final answer from the agent's response.""" - try: - if ( - hasattr(result, "metadata") - and isinstance(result.metadata, dict) - and "result" in result.metadata - ): - return result.metadata["result"] - if hasattr(result, "content"): - if isinstance(result.content, dict) and "result" in result.content: - return result.content["result"] - return str(result.content) - return str(result) - except Exception as e: - logger.warning(f"Extract final answer error: {e}. Raw: {result}") - return str(result) - - -class FinalResult(BaseModel): - result: str = Field( - description="Your solution of the given math problem. Put your final answer in boxed format, e.g., \\boxed{42}" - ) - -system_prompt = """ -You are an agent specialized in solving math problems with tools. Please solve the math problem given to you. You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}. -""" - -class ExampleMathLearn(AgentScopeLearnProtocol): - - trainer: str = Field(default="agentscorpion-trinity") - - async def agentscope_execute(self, init_messages, beyondagent_proxy: BeyondAgentProxy, config): - from agentscope.agent import ReActAgent - from agentscope.formatter import DashScopeChatFormatter - from agentscope.memory import InMemoryMemory - from agentscope.agent import ReActAgent - from agentscope.memory import InMemoryMemory - from agentscope.tool import Toolkit, execute_python_code - - if len(init_messages) >= 2: first_msg, init_messages = init_messages[0], init_messages[1:] - else: first_msg = {"content": "You're a helpful assistant."} - interaction_message = [] - for msg in init_messages: - interaction_message.append(Msg(name=msg.get("name", "user"), content=msg.get("content", ""), role=msg.get("role", "user"))) - - self.toolkit = Toolkit() - self.toolkit.register_tool_function(execute_python_code) - self.agent = ReActAgent( - name="math_react_agent", - sys_prompt=system_prompt, - model=beyondagent_proxy, # type: ignore - formatter=DashScopeChatFormatter(), - toolkit=self.toolkit, - memory=InMemoryMemory(), - ) - msg = Msg("user", init_messages[0]['content'], role="user") - result = await self.agent.reply(msg, structured_model=FinalResult) - final_answer = extract_final_answer(result) - beyondagent_proxy.update_judge_input_dictionary(final_answer=final_answer) - - return beyondagent_proxy - diff --git a/vsdb.py b/vsdb.py deleted file mode 100644 index 28c11c37..00000000 --- a/vsdb.py +++ /dev/null @@ -1,114 +0,0 @@ -import os - -""" -Ray Distributed Debugger VSCode Extension (Recommended) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger `_ VSCode extension. Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier. - - .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true - :alt: Ray Distributed Debugger VSCode extension screenshot - -2. Prerequisites. - - Ensure the following are installed (see the extension README for more detail): - - - Visual Studio Code - - `ray[default]` >= 2.9.1 - - `debugpy` >= 1.8.0 - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/c7098b755ff689859837773a916c857.png?raw=true - :alt: VSCode with Ray prerequisites - -3. Environment Variables. - - To enable post‑mortem debugging, set: - - .. code-block:: bash - - export RAY_DEBUG_POST_MORTEM=1 - - .. admonition:: Note - :class: important - - Be sure to remove any legacy flags before starting Ray: - - - `RAY_DEBUG=legacy` - - `--ray-debugger-external` - -4. Configuring BreakpointsSet up breakpoint() in your code, and submit job to cluster. Then the extension will show the breakpoint information. - - - 1. Insert `breakpoint()` calls into your remote functions. - 2. Submit your job to the cluster. - - The extension will detect active breakpoints and display them in VSCode. - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true - :alt: Detected breakpoint in VSCode - - **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`. - -5. Launching the Debugger. - - Run your job directly from the command line (do not use a `launch.json`): - - .. code-block:: bash - - python job.py - -6. Attaching to a Breakpoint. - - Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger. - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true - :alt: Attaching VSCode debugger to Ray process - -7. Debugging With Multiple breakpoint(). - - For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint. - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/6e83c910a62c82fecb89c6619e001cd.png?raw=true - :alt: Disconnecting and reconnecting the debugger -""" - -def vscode_conditional_breakpoint(tag=None, once=True): - - env_tag = f'HIT_BREAKPOINT_REC_{tag}' - if not os.getenv('RAY_DEBUG_POST_MORTEM'): return - if tag is None: - if once: - if os.getenv(env_tag, "") != "1": - os.environ[env_tag] = "1" - breakpoint() - return - else: - breakpoint() - return - else: - debug_tags = os.getenv('DEBUG_TAGS', '').split('|') - if tag in debug_tags: - if once: - if os.getenv(env_tag, "") != "1": - os.environ[env_tag] = "1" - breakpoint() - return - else: - breakpoint() - return - -import pickle - -def objdump(obj, file="objdump.tmp"): - with open(file, "wb+") as f: - pickle.dump(obj, f) - return - -def objload(file="objdump.tmp"): - import os - if not os.path.exists(file): - return - with open(file, "rb") as f: - return pickle.load(f) - -bp = vscode_conditional_breakpoint \ No newline at end of file