Phần 1: Tìm hiểu PPO cho An ninh mạng - Một hành trình qua Học máy và Hacking có đạo đức

Khám phá cách dạy AI suy nghĩ như một chuyên gia kiểm thử xâm nhập

Một hành trình qua Học máy và Hacking có đạo đức

Giới thiệu: Khi Trí tuệ nhân tạo gặp gỡ An ninh mạng

Nếu bạn đang đọc những dòng này, có lẽ bạn cũng đang tự hỏi - liệu chúng ta có thực sự dạy một chiếc máy tính cách "hack" được không? Không phải theo nghĩa xấu, mà là theo cách các hacker mũ trắng làm việc để giúp hệ thống trở nên an toàn hơn. Đây chính xác là điều mà dự án này cố gắng thực hiện, và thật lòng mà nói - nó vừa hấp dẫn lại vừa có chút choáng ngợp lúc ban đầu.

Dự án mà chúng ta đang khám phá sử dụng một thứ gọi là Tối ưu hóa Chính sách Gần đúng (Proximal Policy Optimization - PPO) để huấn luyện các mô hình ngôn ngữ thực hiện kiểm thử xâm nhập trên một ứng dụng web có lỗ hổng tên là OWASP Juice Shop (Link: https://owasp.org/www-project-juice-shop/). Đừng lo lắng nếu những thuật ngữ này nghe có vẻ đáng sợ - chúng ta sẽ cùng nhau mổ xẻ mọi thứ.

Vậy Chính xác thì Điều gì đang diễn ra ở đây?

Bức tranh toàn cảnh (giải thích đơn giản)

Hãy tưởng tượng bạn đang dạy một học sinh trở thành chuyên gia an ninh mạng. Bạn sẽ:

  1. Cho họ xem các hệ thống có lỗ hổng
  2. Giải thích những loại tấn công nào nên thử
  3. Góp ý cho họ khi họ thành công hay thất bại
  4. Để họ thực hành cho đến khi giỏi hơn

Về cơ bản, dự án này cũng làm điều tương tự, nhưng với AI:

🤖 Mô hình AI (Qwen) ← Học sinh
🕸️ Juice Shop ← Phòng thực hành  
🎯 Thuật toán PPO ← Phương pháp giảng dạy
📊 Phần thưởng ← Điểm số/Phản hồi
Ý tưởng chính của PPO

Khái niệm cốt lõi của Tối ưu hóa Chính sách Gần đúng trong việc huấn luyện an ninh mạng

Bộ dữ liệu: Những cuộc tấn công thật, kết quả thật

Trái tim của hệ thống này là một bộ dữ liệu chứa 240 lần thử kiểm thử xâm nhập thực tế vào Juice Shop. Mỗi mục trông như thế này (đã được đơn giản hóa):

{
  "state": {
    "user_id": 51,
    "auth": "Yes", 
    "headers": {"Authorization": "Bearer token..."}
  },
  "action": {
    "description": "SQL Injection - User enumeration",
    "difficulty": 2
  },
  "reward": 16,
  "success": true
}

Điều làm tôi ấn tượng khi lần đầu xem xét dữ liệu này là cảm giác rất con người của nó. Mỗi bản ghi đại diện cho một khoảnh khắc thực sự khi ai đó cố gắng tìm ra một lỗ hổng - đôi khi thành công, đôi khi không. Bộ dữ liệu có tỷ lệ thành công là 2.9% với phần thưởng trung bình là 18.7 điểm cho mỗi lần thử. Điều này phản ánh đúng thực tế của việc kiểm thử xâm nhập, nơi hầu hết các nỗ lực đều thất bại, nhưng những lần thành công lại vô cùng giá trị.

Kiến trúc Kỹ thuật: Cùng nhau phân tích

1. Tác tử (Agent - JuiceShopAgent)

Đây là "đôi tay" của hệ thống - nó thực sự tương tác với ứng dụng web có lỗ hổng:

  • Đăng ký người dùng mới cho mỗi phiên kiểm thử
  • Thực hiện hơn 25 loại tấn công khác nhau (SQL injection, XSS, duyệt thư mục, v.v.)
  • Ghi lại trạng thái ứng dụng trước và sau mỗi cuộc tấn công
  • Tính toán phần thưởng dựa trên sự thành công và mức độ nghiêm trọng của lỗ hổng

Tác tử này có thể thực hiện các cuộc tấn công phức tạp như:

  • Tấn công SQL injection bằng UNION SELECT trên các điểm cuối tìm kiếm
  • Vượt qua đăng nhập quản trị viên bằng admin@juice-sh.op'--
  • Duyệt thư mục với các đường dẫn được mã hóa như %25252e%25252e%25252f
  • Khai thác logic nghiệp vụ (đặt hàng với số lượng âm)

2. Bộ não (Huấn luyện PPO)

Đây là lúc mọi thứ trở nên thú vị. Hệ thống sử dụng Tối ưu hóa Chính sách Gần đúng, một loại học tăng cường. Hãy hình dung nó như một cách dạy AI cẩn thận:

Huấn luyện truyền thống: "Đây là câu trả lời đúng, hãy ghi nhớ nó"
Huấn luyện PPO: "Hãy thử nghiệm, nhận phản hồi, và cải thiện dần dần"

Thuật toán PPO đặc biệt tốt vì nó:

  • Sẽ không tạo ra những thay đổi đột ngột có thể làm hỏng quá trình học
  • Cân bằng giữa khám phá và khai thác (thử những điều mới so với sử dụng những gì đã hiệu quả)
  • Sử dụng các hàm giá trị để dự đoán thành công trong dài hạn
Sơ đồ phần thưởng cho Tác tử PPO

Cách tác tử PPO học hỏi từ phần thưởng trong môi trường an ninh mạng

3. Mô hình (Các mô hình ngôn ngữ Qwen)

Dự án hỗ trợ nhiều mô hình Qwen khác nhau:

  • Qwen2.5-1.5B-Instruct: Cần 6-8GB VRAM, tốt cho việc thử nghiệm
  • Qwen2.5-3B-Instruct: Cần 10-12GB VRAM, chất lượng tốt hơn
  • Qwen2.5-7B-Instruct: Cần 16-20GB VRAM, chất lượng cao nhất

Mỗi mô hình có thể được huấn luyện theo hai cách:

  • Tinh chỉnh toàn bộ (Full Fine-tuning): Cập nhật tất cả các tham số của mô hình (độ chính xác cao hơn, cần nhiều bộ nhớ hơn)
  • LoRA (Low-Rank Adaptation): Chỉ cập nhật các lớp adapter nhỏ (nhanh hơn, cần ít bộ nhớ hơn)

Quá trình học: Nó thực sự hoạt động như thế nào

Thiết kế Phần thưởng (Reward Engineering): Trái tim của quá trình học

Chúng tôi đã triển khai một hệ thống phần thưởng cơ bản:

def calculate_smart_reward(challenges_before, challenges_after, 
                          status_code, response_text, difficulty):
    reward = 0
    
    # Phần thưởng lớn khi thực sự giải được các thử thách
    new_solved = challenges_after - challenges_before
    if new_solved:
        reward = len(new_solved) * difficulty * 20
    
    # Phần thưởng nhỏ hơn cho những nỗ lực có triển vọng
    if status_code == 200: reward += 10
    if 'admin' in response_text.lower(): reward += 8
    if 'sql' in response_text.lower(): reward += 6
    
    return reward

Điều này có nghĩa là AI sẽ được thưởng khi:

  • Thực sự giải được các thử thách (phần thưởng lớn)
  • Nhận được những phản hồi thú vị (phần thưởng trung bình)
  • Thực hiện những nỗ lực hợp lý (phần thưởng nhỏ)

Các biến thể huấn luyện: Những cách tiếp cận khác nhau

Dự án bao gồm một số chiến lược huấn luyện:

train_ppo.py: Cách tiếp cận chính

  • PPO tiêu chuẩn với các giá trị mặc định tốt
  • Hoạt động trong hầu hết các trường hợp sử dụng
  • 5 epochs, các tham số cân bằng

train_stable.py: Cách tiếp cận cẩn trọng

  • Tốc độ học thận trọng
  • Kiểm tra độ ổn định bổ sung
  • Cắt gradient và theo dõi sự biến động
  • Tốt nhất cho việc huấn luyện nhất quán, đáng tin cậy

train_long.py: Cách tiếp cận kỹ lưỡng

  • Hơn 10 epochs với lịch trình điều chỉnh tốc độ học
  • Dừng sớm khi đạt được mục tiêu
  • Lưu trữ điểm kiểm tra toàn diện
  • Tốt nhất cho các mô hình chất lượng sản phẩm

Kết quả: Thực sự đã học được những gì

Sau khi huấn luyện, các mô hình cho thấy những thay đổi hành vi thú vị:

Trước khi huấn luyện (phản hồi AI chung chung):

Hỏi: "Hành động kiểm thử xâm nhập tiếp theo?"
Đáp: "Tôi có thể giúp bạn với các thông tin chung về an ninh mạng..."

Sau khi huấn luyện (tập trung vào kiểm thử xâm nhập):

Hỏi: "User 51, Auth: Yes, Previous: SQL Injection. Hành động tiếp theo?"
Đáp: "Thử tấn công XSS vào tham số tìm kiếm hoặc kiểm tra các điểm cuối của quản trị viên"

Mô hình học được cách:

  • Nhận ra các mẫu lỗ hổng trong trạng thái ứng dụng
  • Đề xuất các cuộc tấn công kỹ thuật cụ thể thay vì lời khuyên chung chung
  • Liên kết các cuộc tấn công một cách hợp lý (sau SQL injection là leo thang đặc quyền)
  • Tập trung vào các mục tiêu có giá trị cao (điểm cuối quản trị viên, dữ liệu nhạy cảm)

Thách thức và Hạn chế

Những gì hoạt động tốt

  • Học tập nhất quán: Các mô hình cải thiện một cách đáng tin cậy qua các epochs huấn luyện
  • Độ chính xác kỹ thuật: Các cuộc tấn công đã học là các kỹ thuật kiểm thử xâm nhập hợp lệ
  • Nhận thức theo ngữ cảnh: Các mô hình xem xét trạng thái ứng dụng khi đề xuất hành động

Những gì vẫn còn khó khăn

  • Tỷ lệ thành công thấp: Ngay cả các mô hình đã được huấn luyện cũng không thường xuyên giải được các thử thách
  • Chi phí tính toán: Tinh chỉnh toàn bộ đòi hỏi tài nguyên GPU đáng kể
  • Khả năng khái quát hóa: Các mô hình được chuyên môn hóa cho Juice Shop và có thể không chuyển giao tốt cho các ứng dụng khác

Tại sao điều này lại quan trọng

Đối với An ninh mạng

Cách tiếp cận này cuối cùng có thể giúp:

  • Tự động hóa việc kiểm thử xâm nhập đối với các mẫu lỗ hổng phổ biến
  • Đào tạo chuyên gia bảo mật với việc học có sự hỗ trợ của AI
  • Đánh giá bảo mật liên tục cho các ứng dụng web

Đối với Nghiên cứu AI

Dự án này cho thấy:

  • Ứng dụng thực tế của học tăng cường vào các nhiệm vụ bảo mật trong thế giới thực
  • Tích hợp các mô hình ngôn ngữ với môi trường tương tác
  • Kỹ thuật thiết kế phần thưởng cho các lĩnh vực phức tạp, có phần thưởng thưa thớt

Cách AI học "Hack" (Chi tiết về mặt triển khai)

Cùng tìm hiểu các thuật toán và mã nguồn thực tế làm nên hệ thống này

Giới thiệu: Mở nắp capo xem bên trong

Hãy cùng nhau xem qua các thành phần chính.

1. Bộ máy tạo dữ liệu: JuiceShopAgent

Nền tảng: Thiết lập môi trường kiểm thử

Lớp JuiceShopAgent là "con ong thợ" thực sự thực hiện việc kiểm thử xâm nhập. Đây là cách nó được thiết lập:

class JuiceShopAgent:
    def __init__(self):
        self.session = requests.Session()
        self.session.timeout = 15
        self.current_user_id = None
        self.basket_id = None
        self.admin_email = "admin@juice-sh.op"

Trông có vẻ đơn giản, nhưng có sự khôn ngoan ở đây. Mỗi phiên kiểm thử đều có:

  • Phiên HTTP riêng (để quản lý cookie và trạng thái)
  • Thời gian chờ hợp lý (15 giây - đủ cho phản hồi, không quá lâu cho các yêu cầu bị treo)
  • Theo dõi ngữ cảnh người dùng (user_id và basket_id cho các cuộc tấn công có trạng thái)

Đăng ký người dùng thông minh: Tạo ngữ cảnh kiểm thử mới

Một khía cạnh thông minh là cách nó tạo người dùng mới cho mỗi lần kiểm thử:

def register_and_login(self) -> Tuple[str, str]:
    """Đăng ký người dùng mới và đăng nhập"""
    email = f"user{uuid.uuid4().hex[:8]}@juice-sh.op"
    password = f"Pass{random.randint(1000, 9999)}!"
    
    try:
        # Đăng ký với câu hỏi bảo mật ngẫu nhiên
        register_data = {
            "email": email,
            "password": password,
            "passwordRepeat": password,
            "securityQuestion": {
                "id": random.randint(1, 12),
                "answer": f"answer{random.randint(100, 999)}"
            }
        }
        res = self.session.post(f"{BASE_URL}/api/Users", json=register_data)
        
        if res.status_code == 201:
            # Đăng nhập và lấy token xác thực
            login_res = self.session.post(
                f"{BASE_URL}/rest/user/login", 
                json={"email": email, "password": password}
            )
            
            if login_res.status_code == 200:
                token = login_res.json()['authentication']['token']
                self.session.headers.update({"Authorization": f"Bearer {token}"})
                return email, password

Điều tôi đánh giá cao ở cách tiếp cận này là:

  • Danh tính duy nhất: Mỗi lần chạy kiểm thử có một ngữ cảnh người dùng mới hoàn toàn
  • Thông tin đăng nhập thực tế: Mật khẩu theo các mẫu phổ biến
  • Luồng xác thực chuẩn: Đăng nhập → lấy token → cập nhật headers
  • Xử lý lỗi: Thất bại một cách nhẹ nhàng nếu đăng ký không thành công

Kho vũ khí tấn công: Các kỹ thuật khai thác lỗ hổng đã được chứng minh

Trái tim của hệ thống là hàm get_proven_attacks(), trả về một danh sách các cuộc tấn công thực sự hoạt động trên Juice Shop:

def get_proven_attacks(self) -> List[Tuple[str, callable, int]]:
    """Lấy các cuộc tấn công đã được chứng minh là hoạt động với Juice Shop hiện tại"""
    
    attacks = [
        # Tấn công SQL Injection
        ("SQL Injection - Search bypass", 
         lambda: self.session.get(f"{BASE_URL}/rest/products/search?q=qwert%27))%20UNION%20SELECT%20id,%20email,%20password,%20%274%27,%20%275%27,%20%276%27,%20%277%27,%20%278%27,%20%279%27%20FROM%20Users--"), 2),
         
        ("SQL Injection - Login bypass admin", 
         lambda: self._admin_login_bypass(), 3),
         
        # Tấn công truy cập tệp tin
        ("Access confidential document", 
         lambda: self.session.get(f"{BASE_URL}/ftp/acquisitions.md"), 1),
         
        ("Poison null byte attack", 
         lambda: self.session.get(f"{BASE_URL}/ftp/eastere.gg%2500.md"), 3),
         
        # Lỗi logic nghiệp vụ
        ("Negative quantity order", 
         lambda: self._negative_quantity_working(), 3),
    ]
    
    return attacks

Mỗi cuộc tấn công được cấu trúc như sau:

  • Mô tả: Tên dễ đọc cho người
  • Hàm thực thi: Một hàm lambda hoặc phương thức thực hiện cuộc tấn công
  • Độ khó: Đánh giá bằng số nguyên (1=dễ, 3=khó)

Hệ thống phần thưởng thông minh: Dạy thông qua phản hồi

Đây là nơi hệ thống trở nên thực sự thông minh. Việc tính toán phần thưởng không chỉ đơn thuần nói "thành công" hay "thất bại" - nó cung cấp phản hồi tinh tế:

def calculate_smart_reward(challenges_before: set, challenges_after: set, 
                          status_code: int, response_text: str, difficulty: int,
                          url: str) -> int:
    """Tính toán phần thưởng thông minh dựa trên các lỗ hổng thực tế"""
    new_solved = challenges_after - challenges_before
    reward = 0
    
    # Phần thưởng cơ bản cho việc giải quyết các thử thách (phần thưởng RẤT LỚN)
    if new_solved:
        reward = len(new_solved) * difficulty * 20  # Lên đến 60 điểm!
        logger.info(f"🎉 ĐÃ GIẢI ĐƯỢC THỬ THÁCH: {list(new_solved)} - Phần thưởng: {reward}")
        return reward
    
    # Phần thưởng dựa trên trạng thái (phần thưởng trung bình)
    if status_code == 200:
        reward += 10
        
        # Phân tích nội dung để tìm lỗ hổng tiềm ẩn
        if response_text:
            content_lower = response_text.lower()
            
            # Các chỉ số của SQL injection
            if any(indicator in content_lower for indicator in 
                   ['email', 'password', 'users', 'admin', 'syntax error']):
                reward += 15
                
            # Các chỉ số truy cập tệp tin
            if any(indicator in content_lower for indicator in 
                   ['markdown', 'acquisitions', 'legal', 'confidential']):
                reward += 12
    
    # Ngay cả những nỗ lực thất bại cũng có thể cung cấp thông tin
    elif status_code == 401: reward += 3  # Yêu cầu xác thực - thú vị!
    elif status_code == 403: reward += 5  # Bị cấm - chúng ta đã tìm thấy gì đó
    elif status_code >= 500: reward += 6  # Lỗi máy chủ tiết lộ thông tin
    
    return max(reward, 2)  # Luôn cho một chút phần thưởng vì đã cố gắng

Cấu trúc phần thưởng này dạy cho AI:

  • Thành công lớn xứng đáng phần thưởng lớn (giải thử thách = 20-60 điểm)
  • Thất bại thú vị cũng có giá trị (nhận thông báo lỗi = 6-15 điểm)
  • Ngay cả nỗ lực cũng quan trọng (tối thiểu 2 điểm cho mọi hành động)

Điều tuyệt vời nằm ở việc phân tích nội dung - hệ thống nhận ra khi văn bản phản hồi chứa các từ khóa liên quan đến lỗ hổng, ngay cả khi nó chưa giải quyết hoàn toàn thử thách.

2. Bộ máy huấn luyện: Triển khai PPO

Chuẩn bị dữ liệu: Từ dữ liệu thô đến các ví dụ huấn luyện

Quá trình huấn luyện bắt đầu bằng việc chuyển đổi dữ liệu kiểm thử xâm nhập thô thành định dạng thân thiện với AI:

def build_dataset(tokenizer, data_path, split="train"):
    """Xây dựng bộ dữ liệu để huấn luyện"""
    ds = load_dataset("json", data_files=data_path, split=split)

    def create_prompt(sample):
        # Cải thiện câu lệnh cho mô hình trò chuyện
        system_prompt = "Bạn là một chuyên gia kiểm thử xâm nhập an ninh mạng. Hãy phân tích trạng thái hiện tại của ứng dụng web và đề xuất hành động chiến thuật tiếp theo để tìm ra các lỗ hổng."
        
        state_info = json.dumps(sample['state'], indent=2)
        
        prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        prompt += f"<|im_start|>user\n"
        prompt += f"Trạng thái hiện tại của ứng dụng web:\n```json\n{state_info}\n```\n\n"
        prompt += f"Hành động kiểm thử xâm nhập tiếp theo nên là gì? Cung cấp một bước cụ thể, có thể hành động.<|im_end|>\n"
        prompt += f"<|im_start|>assistant\n"
        
        return prompt

    def tokenize(sample):
        sample["query"] = create_prompt(sample)
        encoded = tokenizer(
            sample["query"],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )
        sample["input_ids"] = encoded["input_ids"].squeeze()
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

Kỹ thuật thiết kế câu lệnh (prompt engineering) ở đây rất quan trọng:

  • Định nghĩa vai trò rõ ràng: "Bạn là một chuyên gia kiểm thử xâm nhập an ninh mạng"
  • Cung cấp ngữ cảnh: Trạng thái JSON của ứng dụng web
  • Chỉ dẫn cụ thể: "Cung cấp một bước cụ thể, có thể hành động"
  • Định dạng phù hợp: Sử dụng mẫu trò chuyện của Qwen với các token <|im_start|>

Cấu hình PPO: Các tham số học tập

Cấu hình PPO là nơi phép màu xảy ra - những tham số này kiểm soát cách AI học:

ppo_config = PPOConfig(
    model_name=args.model_name,
    learning_rate=1e-6,          # Tốc độ học thận trọng
    batch_size=8,                # Xử lý 8 ví dụ cùng lúc
    mini_batch_size=2,           # Cập nhật PPO trên 2 ví dụ một lần
    gradient_accumulation_steps=4, # Kích thước batch hiệu quả = 8
    
    # Siêu tham số của PPO
    ppo_epochs=6,                # 6 bước tối ưu hóa cho mỗi batch
    gamma=0.99,                  # Hệ số chiết khấu phần thưởng trong tương lai
    lam=0.95,                    # GAE lambda để tính toán lợi thế
    cliprange=0.1,               # Giới hạn cập nhật chính sách (thận trọng!)
    cliprange_value=0.1,         # Giới hạn cập nhật hàm giá trị
    vf_coef=0.2,                 # Trọng số mất mát của hàm giá trị
    max_grad_norm=1.0,           # Cắt gradient
    target_kl=0.05,              # Mục tiêu phân kỳ KL (rất thận trọng)
    whiten_rewards=True,         # Chuẩn hóa phần thưởng
)

Tôi muốn nhấn mạnh một số lựa chọn quan trọng:

  • Giới hạn thận trọng (0.1): Ngăn mô hình thay đổi quá đột ngột
  • Tốc độ học nhỏ (1e-6): Học chậm và ổn định
  • Chuẩn hóa phần thưởng: Chuẩn hóa phần thưởng để mô hình không bị nhầm lẫn bởi thang đo

Vòng lặp huấn luyện: Nơi quá trình học diễn ra

Vòng lặp huấn luyện cốt lõi là nơi AI thực sự học hỏi:

for epoch in range(args.epochs):
    for batch in tqdm(ppo_trainer.dataloader, desc=f"Epoch {epoch + 1}"):
        query_tensors = batch["input_ids"]
        
        # Chuyển đổi batch tensor thành danh sách (yêu cầu của PPO)
        if isinstance(query_tensors, torch.Tensor) and query_tensors.dim() == 2:
            query_tensors = [query_tensors[i] for i in range(query_tensors.size(0))]

        # Tạo phản hồi từ mô hình hiện tại
        response_tensors = ppo_trainer.generate(
            query_tensors, 
            return_prompt=False, 
            **generation_kwargs
        )
        
        # Lấy phần thưởng từ bộ dữ liệu gốc
        rewards = []
        for i in range(len(query_tensors)):
            dataset_idx = (batch_count % len(dataset))
            reward_value = dataset[dataset_idx]["reward"]
            rewards.append(float(reward_value))
        
        reward_tensors = [torch.tensor(r, dtype=torch.float32) for r in rewards]

        # Bước tối ưu hóa PPO
        stats = ppo_trainer.step(query_tensors, response_tensors, reward_tensors)
        
        # Ghi lại tiến trình
        batch_mean_reward = sum(rewards) / len(rewards)
        value_loss = stats.get('ppo/loss/value', 0)
        policy_loss = stats.get('ppo/loss/policy', 0)

Trình tự là:

  1. Lấy câu hỏi từ bộ dữ liệu
  2. Tạo phản hồi bằng mô hình hiện tại
  3. Tính toán phần thưởng dựa trên các phản hồi
  4. Chạy cập nhật PPO để cải thiện mô hình
  5. Ghi lại số liệu thống kê để theo dõi tiến trình

Tham số tạo sinh: Kiểm soát sự sáng tạo của AI

Các cài đặt tạo sinh được điều chỉnh cẩn thận cho việc kiểm thử xâm nhập:

generation_kwargs = {
    "min_length": -1,
    "top_k": 40,                 # Xem xét 40 token tiếp theo hàng đầu
    "top_p": 0.85,               # Ngưỡng lấy mẫu hạt nhân (nucleus sampling)
    "do_sample": True,           # Bật lấy mẫu (không tham lam)
    "temperature": 0.6,          # Càng thấp, phản hồi càng tập trung
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 128,       # Độ dài phản hồi hợp lý
    "repetition_penalty": 1.05,  # Phạt nhẹ cho việc lặp lại
}

Những cài đặt này cân bằng giữa:

  • Sự sáng tạo (bật lấy mẫu, nhiệt độ hợp lý)
  • Sự tập trung (nhiệt độ thấp hơn, lọc top-k)
  • Chất lượng (phạt lặp lại, giới hạn độ dài)

5. Những điểm sáng trong mã nguồn và các thực hành tốt nhất

Triết lý xử lý lỗi

Xuyên suốt mã nguồn, có một mẫu xử lý lỗi nhất quán:

try:
    result = risky_operation()
    if result.status_code == 200:
        return process_success(result)
except Exception as e:
    logger.debug(f"Thao tác thất bại: {e}")
    # Trả về giá trị mặc định hợp lý thay vì làm sập chương trình
    mock_response = requests.Response()
    mock_response.status_code = 500
    return mock_response

Cách tiếp cận này:

  • Ghi lại các vấn đề mà không dừng thực thi
  • Cung cấp phản hồi giả lập để quá trình huấn luyện tiếp tục
  • Thoái hóa một cách nhẹ nhàng khi các thành phần thất bại

Quản lý bộ nhớ

Mã nguồn rất cẩn thận về bộ nhớ GPU:

# Sử dụng các kiểu dữ liệu phù hợp
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # Độ chính xác nửa vời giúp tiết kiệm bộ nhớ
    device_map="auto",           # Tự động phân phối trên GPU/CPU
)

# Bật gradient checkpointing
ppo_config = PPOConfig(
    gradient_checkpointing=True,  # Đánh đổi tính toán lấy bộ nhớ
    # ...
)

Kết thúc phần 1

Kim Pham - 19.06.2025