3th1ca14aX0r

XCodeEval Dataset Growth Pipeline

Jun 16th, 2025
#cat <<'EOF' > mini_run.py
#!/usr/bin/env python3
import os
import re
import json
import zipfile
import hashlib
import tempfile
import subprocess
import argparse
import time
from tqdm import tqdm
from llama_cpp import Llama
import difflib
import psutil
import chardet
from func_timeout import func_timeout, FunctionTimedOut

# ===== Configuration =====
MODEL_PATH = "/home/davetmire85/gguf_models/mistral-7b-instruct-v0.2.Q4_K_M.gguf"
INPUT_PATH = "/home/davetmire85/sigil_inputs/rust_examples_utf8.jsonl"
OUTPUT_DIR = "enriched_outputs"
ZIP_PATH = "rich_results.zip"
BUCKET_URI = "gs://sigil-transfer-bucket/rich_results.zip"
# Read the token from the environment; never hard-code credentials in source.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN", "")

# Reduced for a 26GB RAM environment
BATCH_SIZE = 2
MAX_TOKENS = 512
CTX_WINDOW = 4096
MAX_CODE_SIZE = 10000
BUG_TYPES = {
    "OperatorSwap": r"(\+|\-|\*|\/|==|!=)",
    "UnwrapToQuestionMark": r"\.unwrap\(\)",
    "BoundaryError": r"for\s.*\bin\b\s(0\.\.|\.\.=)?\d+",
    "LifetimeMissing": r"&\s*\w+\s*(?![:+])",
    "TypeMismatch": r"as\s+[A-Za-z0-9_]+",
    "IndexError": r"\[[A-Za-z0-9_]+\]",
}

# Resource monitoring
def memory_safe():
    mem = psutil.virtual_memory()
    return mem.available > (512 * 1024 * 1024)  # 512MB buffer

def disk_safe():
    usage = psutil.disk_usage('/tmp')
    return usage.free > (100 * 1024 * 1024)  # 100MB min free

def compute_uid(code: str) -> str:
    if len(code) > MAX_CODE_SIZE:
        raise ValueError(f"Code exceeds maximum size ({MAX_CODE_SIZE} chars)")
    return hashlib.sha256(code.encode("utf-8")).hexdigest()[:32]

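# compute_uid yields a deterministic 32-hex-char ID (a truncated SHA-256
# digest), so identical source code always maps to the same UID across runs.
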
def analyze_edits(original: str, fixed: str) -> dict:
    # Diff line-by-line so the counts and changed_lines refer to line indices,
    # not character offsets.
    d = difflib.SequenceMatcher(None, original.splitlines(), fixed.splitlines())
    metrics = {
        "similarity_score": d.ratio(),
        "equal_cnt": 0,
        "replace_cnt": 0,
        "delete_cnt": 0,
        "insert_cnt": 0,
        "fix_ops_cnt": 0,
        "changed_lines": []
    }
    for op, i1, i2, j1, j2 in d.get_opcodes():
        if op == "equal":
            metrics["equal_cnt"] += (i2 - i1)
        elif op == "replace":
            metrics["replace_cnt"] += (i2 - i1)
            metrics["fix_ops_cnt"] += 1
            metrics["changed_lines"].extend(range(i1, i2))
        elif op == "delete":
            metrics["delete_cnt"] += (i2 - i1)
            metrics["fix_ops_cnt"] += 1
            metrics["changed_lines"].extend(range(i1, i2))
        elif op == "insert":
            metrics["insert_cnt"] += (j2 - j1)
            metrics["fix_ops_cnt"] += 1
    metrics["changed_lines"] = sorted(set(metrics["changed_lines"]))
    return metrics

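# Example (illustrative): analyze_edits("a\nb\nc", "a\nx\nc") reports
# equal_cnt == 2, replace_cnt == 1, fix_ops_cnt == 1, changed_lines == [1],
# and a similarity_score of roughly 0.67.
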
def classify_bug(buggy: str, fixed: str) -> str:
    for bug_type, pattern in BUG_TYPES.items():
        if re.search(pattern, buggy) and not re.search(pattern, fixed):
            return bug_type
    return "ComplexChange"

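# Example (illustrative): classify_bug("let x = s.parse().unwrap();",
# "let x = s.parse()?;") returns "UnwrapToQuestionMark", because the
# `.unwrap()` pattern matches the buggy code but not the fix; if no single
# pattern disappears, the change is labeled "ComplexChange".
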
def get_execution_status(code: str) -> dict:
    # Verdicts: AC = ran with exit code 0, RE = nonzero exit, CE = compile
    # error, TLE = time limit exceeded, RESOURCE = skipped (low memory/disk).
    # Check resources before starting a build
    if not memory_safe() or not disk_safe():
        return {"status": "RESOURCE", "error": "Insufficient system resources"}

    with tempfile.TemporaryDirectory(prefix="sigil_") as tmpdir:
        crate_name = f"sigil_{compute_uid(code)[:8]}"
        src_dir = os.path.join(tmpdir, "src")
        os.makedirs(src_dir, exist_ok=True)
        with open(os.path.join(tmpdir, "Cargo.toml"), "w") as f:
            f.write(f"""
[package]
name = "{crate_name}"
version = "0.1.0"
edition = "2021"

[profile.release]
opt-level = 'z'  # Optimize for size
lto = false
""")
        with open(os.path.join(src_dir, "main.rs"), "w") as f:
            f.write(code)

        # Constrained build environment
        env = os.environ.copy()
        env["CARGO_BUILD_JOBS"] = "1"  # Limit parallelism

        try:
            build = subprocess.run(
                ["cargo", "build", "--release"],
                cwd=tmpdir,
                capture_output=True,
                timeout=30,
                text=True,
                env=env
            )
            if build.returncode != 0:
                return {"status": "CE", "error": build.stderr[:500]}

            start = time.monotonic()
            run = subprocess.run(
                [os.path.join(tmpdir, "target/release", crate_name)],
                capture_output=True,
                timeout=5,
                text=True
            )
            return {
                "status": "AC" if run.returncode == 0 else "RE",
                "exit_code": run.returncode,
                "output": run.stdout[:500],
                "exec_time": time.monotonic() - start  # wall-clock seconds
            }
        except subprocess.TimeoutExpired:
            return {"status": "TLE", "error": "timeout"}
        except Exception as e:
            return {"status": "ERROR", "error": str(e)[:200]}

def load_model():
    return Llama(
        model_path=MODEL_PATH,
        n_ctx=CTX_WINDOW,
        n_threads=2,  # Fixed for a 4 vCPU environment
        n_gpu_layers=35,  # Only takes effect with a GPU-enabled llama-cpp-python build
        n_batch=min(256, BATCH_SIZE * 2),
        f16_kv=True,
        use_mlock=True,
        verbose=False
    )

def detect_and_decode(filepath):
    with open(filepath, "rb") as f:
        raw = f.read()
    detection = chardet.detect(raw)
    encoding = detection["encoding"] or "utf-8"
    try:
        text = raw.decode(encoding)
    except UnicodeDecodeError:
        text = raw.decode("utf-8", errors="replace")
    return text.replace("\r\n", "\n")

def load_entries(path):
    entries = []
    content = detect_and_decode(path)
    for line in content.splitlines():
        try:
            data = json.loads(line)
            code = data.get("before") or data.get("source_code")
            if not code:
                continue
            entries.append({
                "bug_code_uid": data.get("code_uid", compute_uid(code)),
                "bug_source_code": code,
                "metadata": {k: v for k, v in data.items() if k not in ["before", "source_code"]}
            })
        except Exception as e:
            print(f"[!] Failed to parse: {e}")
    return entries

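# Each input line is one JSON object whose buggy program sits in a "before"
# or "source_code" field (field names per load_entries above; the values here
# are illustrative):
#   {"code_uid": "abc123", "before": "fn main() { println!(\"hi\"); }"}
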
def process_batch(llm: Llama, batch: list) -> list:
    results = []
    for entry in batch:
        try:
            # Check resources before each entry
            if not memory_safe() or not disk_safe():
                raise ResourceWarning("Insufficient system resources")

            # Truncation is by characters, a rough proxy for the token budget.
            prompt = f"""<s>[INST] You are a senior Rust developer. Fix this code:
{entry['bug_source_code'][:CTX_WINDOW//2]}
Provide ONLY the fixed Rust code without explanations. [/INST] Fixed Code:\n"""
            try:
                response = func_timeout(
                    120,
                    llm,
                    args=(prompt,),
                    kwargs={
                        "max_tokens": MAX_TOKENS,
                        "temperature": 0.1,
                        "top_p": 0.9,
                        "stop": ["</s>", "```", "\n\n\n", "[INST]"],
                        "echo": False
                    }
                )
                fix = response["choices"][0]["text"].strip().split("Fixed Code:")[-1].strip()
            except FunctionTimedOut:
                fix = entry['bug_source_code']  # Fall back to the original
                print("Model timeout, using original code")

            if not fix:
                raise ValueError("Empty fix generated")

            diff_metrics = analyze_edits(entry['bug_source_code'], fix)
            execution = {
                "bug": get_execution_status(entry['bug_source_code']),
                "fix": get_execution_status(fix)
            }
            # Dominant fix op = the edit operation that touched the most lines.
            op_counts = {
                "replace": diff_metrics["replace_cnt"],
                "delete": diff_metrics["delete_cnt"],
                "insert": diff_metrics["insert_cnt"],
            }
            results.append({
                **entry,
                "fix_source_code": fix,
                "fix_code_uid": compute_uid(fix),
                "apr_id": compute_uid(entry['bug_code_uid'] + compute_uid(fix)),
                **diff_metrics,
                "bug_type": classify_bug(entry['bug_source_code'], fix),
                "execution": execution,
                "potential_dominant_fix_op": max(op_counts, key=op_counts.get),
                "resource_usage": {
                    "mem_available_mb": psutil.virtual_memory().available // (1024 * 1024)
                }
            })
        except Exception as e:
            results.append({
                "error": str(e),
                "bug_code_uid": entry.get('bug_code_uid', 'unknown'),
                "partial_data": entry.get('metadata', {})
            })
    return results

def package_results(output_dir):
    with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for fname in os.listdir(output_dir):
            zipf.write(os.path.join(output_dir, fname), fname)

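# Uploading assumes the Google Cloud SDK's gsutil is on PATH and its
# credentials can write to BUCKET_URI.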
def upload_results():
    try:
        subprocess.run(["gsutil", "cp", ZIP_PATH, BUCKET_URI], check=True)
    except Exception as e:
        print(f"Upload failed: {str(e)}")

def main():
    parser = argparse.ArgumentParser(description="Rust Code Fix Pipeline")
    parser.add_argument("--input", default=INPUT_PATH, help="Input JSONL path")
    parser.add_argument("--output_dir", default=OUTPUT_DIR, help="Output directory")
    parser.add_argument("--skip_existing", action="store_true", help="Skip processed shards")
    args = parser.parse_args()

    # Only export the token if one was actually provided.
    if HUGGINGFACE_TOKEN:
        os.environ["HUGGINGFACE_HUB_TOKEN"] = HUGGINGFACE_TOKEN
    os.makedirs(args.output_dir, exist_ok=True)

    print("Loading model...")
    llm = load_model()
    print("Model loaded")

    print(f"Reading input from: {args.input}")
    entries = load_entries(args.input)
    print(f"Loaded {len(entries)} valid entries")
    if not entries:
        print("[!] No entries loaded. Exiting.")
        return

    total_batches = (len(entries) + BATCH_SIZE - 1) // BATCH_SIZE
    with tqdm(total=total_batches, desc="Processing") as pbar:
        for batch_idx in range(0, len(entries), BATCH_SIZE):
            shard_path = os.path.join(args.output_dir, f"shard_{batch_idx//BATCH_SIZE:04d}.jsonl")
            if args.skip_existing and os.path.exists(shard_path):
                pbar.update(1)
                continue

            # System resource check
            if not memory_safe():
                print("Low memory, pausing for 30s...")
                time.sleep(30)

            batch = entries[batch_idx:batch_idx + BATCH_SIZE]
            results = process_batch(llm, batch)
            with open(shard_path, "w") as f:
                for result in results:
                    f.write(json.dumps(result, ensure_ascii=False) + "\n")
            pbar.update(1)

    print("Packaging results...")
    package_results(args.output_dir)

    print("Uploading results...")
    upload_results()
    print("Pipeline completed!")

if __name__ == "__main__":
    main()
#EOF
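# Example invocation (flags are optional; the defaults in the Configuration
# block apply when they are omitted):
#   python3 mini_run.py --input rust_examples_utf8.jsonl \
#       --output_dir enriched_outputs --skip_existing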