
Training a Korean Visual Document Retriever Model

Most retrieval systems work on text. You embed a query, embed a corpus of text chunks, and find the nearest neighbors. But real documents — government forms, financial reports, textbooks, product manuals — are not plain text. They are rendered pages: tables with borders, multi-column layouts, figures, scanned stamps, and handwritten annotations. When you run OCR on these, you lose spatial structure. When you throw away the image entirely, you lose everything the layout communicates.

Korean documents add another layer of difficulty: Korean-language corpora are severely underrepresented in vision-language retrieval research. Most public training sets are English-only, and models trained on them degrade sharply on Korean text rendered inside page images.

This post documents Ko-VDR: a multimodal document retrieval model that takes a text query and returns the correct document page image, trained on a mixed Korean-English corpus of 325k pairs. I'll walk through the full pipeline (architecture, data, loss stack, training tricks, and evaluation) in as much detail as I needed when actually building it.

1. The Problem: Cross-Modal Document Retrieval

The task is simple to state. Given a text query — "What is the refund policy?" — return the page image from a corpus of PDFs that answers it. No OCR, no text extraction, no chunking. The model must understand both the natural language query and the visual content of the document page.

The standard solution is a dual-encoder (bi-encoder) retriever:

Text Query  ──[Encoder]──►  query embedding  ─┐
                                               ├─► cosine_sim → rank → retrieve
Doc Image   ──[Encoder]──►  doc   embedding  ─┘

Both the query and each document page are mapped into the same embedding space as fixed-length vectors of floating-point numbers. Documents whose vectors point in the same direction as the query vector (high cosine similarity) are considered relevant. At inference time, all document embeddings are pre-computed offline and stored in a FAISS index. Retrieval is then an approximate nearest-neighbor search, which stays fast even on large corpora because it avoids scoring the query against every page.
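The retrieval step itself fits in a few lines. This is a minimal sketch, assuming the embeddings are L2-normalized so that an inner product equals cosine similarity; it uses exact brute-force search in NumPy as a stand-in for the FAISS index, and the helper names and toy vectors are made up for illustration.

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Scale each row to unit length so that a dot product equals cosine similarity.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def retrieve(query_emb: np.ndarray, doc_embs: np.ndarray, k: int = 3) -> tuple[np.ndarray, np.ndarray]:
    """Exact top-k search by cosine similarity (what a flat inner-product index returns)."""
    q = l2_normalize(query_emb)
    d = l2_normalize(doc_embs)          # in a real system this is done once, offline
    scores = d @ q                      # cosine similarity of the query with every page
    top = np.argsort(-scores)[:k]       # highest similarity first
    return top, scores[top]

# Toy corpus: four "page" embeddings in a 4-dim space (illustrative values only).
docs = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.7, 0.7, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
])
query = np.array([0.6, 0.8, 0.0, 0.0])
ids, sims = retrieve(query, docs, k=2)
print(ids, np.round(sims, 3))
```

An approximate index trades a little recall for not computing `d @ q` over the whole corpus; the ranking it approximates is exactly the one above.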

The challenge is teaching the encoder that the query "refund policy" and a specific page image containing "환불 정책" (Korean for "refund policy") in a table are semantically equivalent. That is what training is for.

2. Model Architecture

2.1 Base Model: Qwen3-VL-Embedding-2B

The backbone is Qwen/Qwen3-VL-Embedding-2B, a 2B-parameter vision-language model designed specifically for embedding tasks. It is a decoder-only transformer (causal LM) that processes interleaved sequences of text tokens and visual patch tokens.

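When a document page image is fed in, the VLM processor divides it into 28×28-pixel patches, encodes each patch into a vector via a small vision encoder, and inserts the resulting patch tokens into the same sequence as the text tokens. Self-attention then runs over the mixed sequence under the causal mask: each token attends to every token before it, so tokens that come after the image (including the final token used for pooling) attend to every image patch. In a bi-encoder, the query and the page are never in the same sequence, so this attention only relates text and visual tokens within one input. What ties the query "refund policy" to the visual tokens of the relevant table region is the shared embedding space, which is exactly what contrastive training shapes.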

The final embedding is extracted via EOS token pooling — the hidden state at the last non-padding token position:

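last_non_padding_token_idx = attention_mask.sum(dim=1) - 1   # assumes right padding
sentence_embedding = token_embeddings[torch.arange(token_embeddings.size(0)), last_non_padding_token_idx]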

This gives a single 4096-dimensional vector per input. The model is wrapped in a SentenceTransformer as a two-module sequential:

model[0] = Transformer(Qwen3-VLModel)   # → token_embeddings  (B, seq_len, 4096)
model[1] = Pooling(...)                 # → sentence_embedding (B, 4096)  [EOS token]
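What the pooling module computes can be sketched in NumPy. This is a minimal sketch, assuming a 0/1 attention mask; `last_token_pool` is a hypothetical helper, not the sentence-transformers `Pooling` code. It handles both padding sides because causal-LM tokenizers often pad on the left, in which case the last position is always a real token.

```python
import numpy as np

def last_token_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Hidden state of the last non-padding token for each sequence.

    token_embeddings: (B, seq_len, hidden); attention_mask: (B, seq_len) of 0/1.
    """
    if attention_mask[:, -1].all():
        # Left padding (or no padding at all): the final position is a real token in every row.
        return token_embeddings[:, -1]
    # Right padding: the last real token sits at index (number of real tokens - 1).
    last_idx = attention_mask.sum(axis=1) - 1
    return token_embeddings[np.arange(token_embeddings.shape[0]), last_idx]

emb = np.arange(2 * 5 * 3, dtype=float).reshape(2, 5, 3)   # (B=2, seq_len=5, hidden=3)
right_padded = np.array([[1, 1, 1, 0, 0],                  # 3 real tokens -> position 2
                         [1, 1, 1, 1, 1]])                 # no padding    -> position 4
left_padded = np.array([[0, 0, 1, 1, 1],
                        [1, 1, 1, 1, 1]])
print(last_token_pool(emb, right_padded).tolist())  # [[6.0, 7.0, 8.0], [27.0, 28.0, 29.0]]
print(last_token_pool(emb, left_padded).tolist())   # [[12.0, 13.0, 14.0], [27.0, 28.0, 29.0]]
```

Last-token pooling suits a causal model because, under the causal mask, that final position is the only one whose hidden state has attended to the entire input.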

One important flag: model[0].unpad_inputs = False. Qwen3-VL uses visual positional embeddings that break when Flash Attention's variable-length unpacking mode is active. This flag is non-obvious and easy to miss; without it the training run produces silently incorrect embeddings.

Key model configuration:

import torch
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("Qwen/Qwen3-VL-Embedding-2B",
    model_kwargs={
        "attn_implementation": "flash_attention_2",
        "torch_dtype": torch.bfloat16
    },
    processor_kwargs={
        "max_pixels": 1280 * 28 * 28,
        "min_pixels": 4 * 28 * 28
    }
)
model[0].unpad_inputs = False

2.2 LoRA Fine-Tuning

Qwen3-VL-Embedding-2B has ~2 billion parameters. Full fine-tuning requires ~16 GB of optimizer state alone: Adam stores two running averages per parameter, so 2B × 2 × 4 bytes (fp32) ≈ 16 GB. Instead we use LoRA (Low-Rank Adaptation): freeze the original weights and inject small trainable matrices into the attention and FFN projections.

LoRA replaces each weight update ΔW with a low-rank factorization B·A, where A ∈ ℝ^{r×d_in} and B ∈ ℝ^{d_out×r}, scaled by lora_alpha / r (which is 1 here, since both are 32). With r=32, only ~1.5% of parameters (~30M) are trained, reducing optimizer state from ~16 GB to ~30M × 2 × 4 bytes ≈ 240 MB.

from peft import LoraConfig
lora_cfg = LoraConfig(
    r=32, lora_alpha=32,
    lora_dropout=0.0, bias="none",
    task_type="FEATURE_EXTRACTION",
    target_modules=["q_proj","k_proj","v_proj","up_proj","down_proj","gate_proj"]
)
model.add_adapter(lora_cfg)
# freeze everything that isn't LoRA (vision encoder included)
for name, param in model.named_parameters():
    if "lora_" not in name.lower():
        param.requires_grad = False

Critically, the vision encoder is intentionally frozen. LoRA is only applied to the LLM backbone's attention and feed-forward layers. The visual feature extractor has a different module naming convention, which naturally excludes it from the LoRA target modules, and keeping it frozen stabilizes training significantly.
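The LoRA update and the parameter arithmetic from this section can be checked with a small NumPy sketch. It assumes the standard LoRA formulation h = W·x + (lora_alpha / r)·B·A·x with B initialized to zero, so training starts from the unmodified model; `LoRALinear` and `adam_state_bytes` are hypothetical names, and the 2048×2048 layer is an illustrative size, not a claim about a specific Qwen3-VL projection.

```python
import numpy as np

class LoRALinear:
    """Frozen weight W plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, d_in: int, d_out: int, r: int, alpha: float, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal((d_out, d_in))      # frozen pretrained weight
        self.A = 0.01 * rng.standard_normal((r, d_in))   # trainable, small random init
        self.B = np.zeros((d_out, r))                    # trainable, zero init
        self.scaling = alpha / r

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # x: (batch, d_in) -> (batch, d_out)
        return x @ self.W.T + self.scaling * (x @ self.A.T) @ self.B.T

    def trainable_params(self) -> int:
        return self.A.size + self.B.size                 # r * (d_in + d_out)

layer = LoRALinear(d_in=2048, d_out=2048, r=32, alpha=32)
x = np.random.default_rng(1).standard_normal((4, 2048))
# Because B starts at zero, the adapted layer initially matches the frozen layer exactly.
print(np.allclose(layer(x), x @ layer.W.T))
print(layer.trainable_params(), layer.W.size)           # 131072 4194304

def adam_state_bytes(n_params: float) -> float:
    # Adam keeps 2 fp32 values (4 bytes each) per trained parameter.
    return n_params * 2 * 4

print(adam_state_bytes(2e9) / 1e9, "GB")                # full fine-tune: 16.0 GB
print(adam_state_bytes(30e6) / 1e6, "MB")               # ~30M LoRA params: 240.0 MB
```

The whole-model trainable fraction (~1.5%) is lower than the per-matrix ratio printed here because embeddings, the vision encoder, and un-adapted projections add frozen parameters without adding LoRA ones.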

3. Dataset and Data Pipeline

3.1 Training Data Sources

The training corpus is a mixed Korean-English dataset of 325,023 pairs built from four sources:

Source       Images   Language   Notes
colpali      ~118k    English    vidore/colpali_train_set
llamaindex   ~54k     English    VDR evaluation set
kovdr_pub    ~7.5k    Korean     Korean VDR public release
kovdr_priv   ~28k     Korean     Korean VDR private; gated redistribution

Total: 325,023 training rows · 207,522 unique images · ~74 GB. Each row has one anchor (text query), one positive (matching page image path), and seven hard negative image paths. The Korean and English data are intentionally mixed within the same corpus rather than kept separate — this cross-lingual mixing is what drives the model's ability to match English queries against Korean document pages and vice versa.

3.2 Stage 1 — Image Extraction

All four source datasets are HuggingFace Arrow datasets containing PIL images. Decoding each image to PIL and re-encoding is ~100× slower than copying the raw compressed bytes directly:

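from datasets import Image as HFImage
ds = ds.cast_column("image", HFImage(decode=False))  # raw bytes, no PIL decode
for row, out_path in zip(ds, out_paths):             # out_paths: one target Path per row (illustrative)
    img_bytes = row["image"].get("bytes")
    out_path.write_bytes(img_bytes)                   # ~100× faster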

Because the raw bytes are copied as-is, each file keeps its source format: JPEG sources stay JPEG and PNG sources stay PNG. Where an image does have to be re-encoded, JPEG is written at quality 95 (perceptually lossless for document scans). The pipeline is resume-friendly (it skips already-extracted files), which matters when extraction runs over 74 GB of data. The output is a pre-mining dataset of ~325k rows with schema (anchor, positive, source, doc_id).
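The resume-friendly copy step can be sketched with the standard library alone. This is a minimal sketch, assuming each image arrives as raw encoded bytes (as `HFImage(decode=False)` provides) and that files are named by `doc_id`; `guess_extension` and `extract_image` are hypothetical helpers. The format is detected from the file signature, so the original encoding is kept byte-for-byte.

```python
import tempfile
from pathlib import Path

JPEG_MAGIC = b"\xff\xd8\xff"
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"

def guess_extension(img_bytes: bytes) -> str:
    """Pick a file extension from the image's magic bytes."""
    if img_bytes.startswith(JPEG_MAGIC):
        return ".jpg"
    if img_bytes.startswith(PNG_MAGIC):
        return ".png"
    raise ValueError("unsupported image format")

def extract_image(img_bytes: bytes, out_dir: Path, doc_id: str) -> tuple[Path, bool]:
    """Write the raw bytes once; return (path, written). Existing files are skipped."""
    out_path = out_dir / f"{doc_id}{guess_extension(img_bytes)}"
    if out_path.exists():                        # resume: a previous run already wrote this file
        return out_path, False
    tmp_path = out_path.parent / (out_path.name + ".tmp")
    tmp_path.write_bytes(img_bytes)              # no decode / re-encode: bytes copied verbatim
    tmp_path.replace(out_path)                   # rename, so a crash never leaves a half-written file
    return out_path, True

with tempfile.TemporaryDirectory() as d:
    fake_png = PNG_MAGIC + b"rest-of-file"
    p, written = extract_image(fake_png, Path(d), "doc_0001")
    _, written_again = extract_image(fake_png, Path(d), "doc_0001")
    print(p.name, written, written_again)        # doc_0001.png True False
```

Writing to a temporary name first matters for resumability: an interrupted write leaves only a `.tmp` file, so the "skip if exists" check never mistakes a truncated image for a finished one.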

3.3 Stage 2 — Hard Negative Mining

Random negatives are easy for the model to reject — it quickly learns to separate completely unrelated queries from documents. Hard negatives are semantically similar but incorrect documents: pages that look like they could answer the query but don't. Training on these pushes the model to learn fine-grained distinctions.

We mine hard negatives using a larger model than the one we're training: Qwen/Qwen3-VL-Embedding-8B. Better embeddings from the 8B model → higher quality negatives for the 2B model. The mining procedure:

  1. Embed all 207,522 unique document page images with the 8B model into a FAISS index.
  2. For each query, retrieve the top-K nearest document images.
  3. Filter candidates: keep only those where s_pos − s_neg ≥ 0.05 (absolute margin). Any negative scoring above s_pos − 0.05 (close to, or even above, the positive) is discarded as a likely false negative.
  4. Select the top 7 remaining candidates as hard negatives.

# absolute_margin filter: discard negatives within `absolute_margin` of the positive
removed = scores + absolute_margin > max_positive_scores

Mining runs on GPUs 4–7 in a multiprocess pool, with faiss_batch_size=16384 tuned for A6000 48GB cards. Output is streamed to JSONL first (to avoid memory spikes), then saved as a multi-shard HuggingFace dataset (50k rows per shard). The final dataset is uploaded to the Hub with the corpus and train splits kept separate to avoid redundant image storage.
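The filter-and-select steps from the list above can be sketched in NumPy. This is a minimal sketch, assuming unit-normalized embeddings, one positive per query, and exact search instead of the FAISS index; `mine_hard_negatives_sketch` is a hypothetical name, not the sentence-transformers `mine_hard_negatives` utility, though the margin rule is the same `removed = ...` comparison.

```python
import numpy as np

def mine_hard_negatives_sketch(q_emb, d_emb, positive_ids, top_k=30, num_negatives=7, absolute_margin=0.05):
    """Return, per query, up to `num_negatives` hard-negative doc indices (hardest first).

    q_emb: (Q, D) and d_emb: (N, D), both L2-normalized; positive_ids: (Q,) index of each positive.
    """
    scores = q_emb @ d_emb.T                                    # (Q, N) cosine similarities
    pos_scores = scores[np.arange(len(q_emb)), positive_ids]   # s_pos for each query
    results = []
    for i in range(len(q_emb)):
        cand = np.argsort(-scores[i])[:top_k]                   # top-K nearest pages
        cand = cand[cand != positive_ids[i]]                    # never return the positive itself
        # Discard negatives within `absolute_margin` of (or above) the positive: likely false negatives.
        removed = scores[i, cand] + absolute_margin > pos_scores[i]
        results.append(cand[~removed][:num_negatives].tolist())
    return results

cos = np.array([0.9, 0.95, 0.87, 0.8, 0.5])           # doc 0 is the labeled positive
docs = np.stack([cos, np.sqrt(1 - cos**2)], axis=1)   # unit vectors with those cosines to [1, 0]
query = np.array([[1.0, 0.0]])
print(mine_hard_negatives_sketch(query, docs, positive_ids=np.array([0])))  # [[3, 4]]
```

Doc 1 (0.95, above the positive) and doc 2 (0.87, within 0.05 of it) are dropped as likely false negatives; docs 3 and 4 survive, hardest first.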

4. The Loss Stack

The full loss is a three-layer nested structure:

MatryoshkaLoss
  └── SelfGuideCachedMultipleNegativesRankingLoss  (custom)
        └── CachedMultipleNegativesRankingLoss  (upstream)

Each layer addresses a distinct problem. Let me unpack them bottom-up.

4.1 InfoNCE: The Foundation

The core training signal is InfoNCE (also called MultipleNegativesRankingLoss in sentence-transformers). Given a batch of N (query, positive document) pairs, every other document in the batch becomes a negative for each query:

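L_i = -log [ exp(s_pos × scale) / Σ_j exp(s_ij × scale) ]
    = -(s_pos × scale − logsumexp_j(s_ij × scale))

Here s_ij is the cosine similarity between query i and document j, and s_pos = s_ii. With scale=20.0 (the inverse temperature 1/τ, so τ = 0.05) and batch size 128, each query is contrasted against the other 127 positives in the batch. The loss is simply cross-entropy: the model must assign the highest probability to the diagonal of the N×N similarity matrix. For simplicity, the rest of this post counts only these in-batch positives as negatives. In the actual run, each row also carries its seven mined hard negatives, and sentence-transformers' MNRL losses append every negative column to the candidate pool, so the real similarity matrix is 128 × 1,024 and each query is scored against 1,023 negatives.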

More negatives per query = harder task = stronger model. Everything else in this system is a technique to increase the effective number of negatives while staying within hardware limits.
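The two forms of the loss can be checked against each other numerically. This is a minimal NumPy sketch, assuming unit-normalized embeddings and only in-batch negatives (one document per query, positive on the diagonal); `info_nce` is a hypothetical helper name.

```python
import numpy as np

def logsumexp(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable log(sum(exp(x))): subtract the row max before exponentiating.
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def info_nce(q: np.ndarray, d: np.ndarray, scale: float = 20.0) -> float:
    """Mean InfoNCE over the batch; d[i] is the positive for q[i], all other rows are negatives."""
    logits = scale * (q @ d.T)                       # (N, N) scaled cosine similarities
    positive = np.diag(logits)                       # s_pos * scale for each query
    return float(np.mean(-(positive - logsumexp(logits, axis=1))))

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
q /= np.linalg.norm(q, axis=1, keepdims=True)
d = q + 0.1 * rng.standard_normal((8, 16))           # positives: noisy copies of the queries
d /= np.linalg.norm(d, axis=1, keepdims=True)

# The explicit -log softmax form gives the same value as the logsumexp form.
logits = 20.0 * (q @ d.T)
explicit = -np.log(np.exp(np.diag(logits)) / np.exp(logits).sum(axis=1))
print(info_nce(q, d), explicit.mean())
```

The logsumexp form is the one to implement: with scale=20, exp(20) is already ~5×10⁸, and larger scales or unnormalized scores would overflow the explicit form.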

4.2 GradCache: The 3-Phase Memory Trick

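To compute InfoNCE, all N embeddings must be available at once to build the N×N similarity matrix, and in a naive implementation every input's forward-pass activations must also stay in GPU memory until the backward pass runs. For N=128 document page images, each generating up to 1280 visual patch tokens × 4096 dimensions at every layer, this is gigabytes of intermediate activation tensors, and the GPU runs out of memory.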

GradCache (Gao et al. 2021) solves this by decoupling the embedding step from the loss step. It runs in three phases:

Phase 1 — Embed everything in mini-batches, no gradient tracking:

cache = []
for mini_batch in split_into_chunks(full_batch, size=4):   # split_into_chunks: illustrative helper
    with torch.no_grad():             # don't build computation graph
        emb = model(mini_batch)       # shape: [4, 4096]
        emb = emb.detach()            # discard graph, keep numbers
        cache.append(emb.requires_grad_(True))  # make artificial leaf

No graph = no activation memory overhead. We hold only 4 samples' worth of tensors at a time, discarding them after each chunk. The output is 128 cached embedding vectors that, after .detach().requires_grad_(True), become artificial leaf nodes — they have a .grad slot but no connection to the model weights.

Phase 2 — Compute the full loss on cached embeddings, backward to fill embedding gradients:

loss = InfoNCE(all_cached_embeddings)  # full 128×128 matrix
loss.backward()                         # fills embedding.grad = ∂L/∂embedding

The model is not called here — only the cached vectors are used. No activation graphs from the model exist in memory. The 128×128 similarity matrix is unavoidable (it is the loss), but everything else is minimal.

Phase 3 — Re-embed each mini-batch with gradients, chain via dot-product surrogate:

for mini_batch, grad_slice in zip(chunks, grad_slices):
    emb = model(mini_batch)                     # build graph for 4 samples
    surrogate = (emb * grad_slice).sum()        # dot product shortcut
    surrogate.backward()                        # ∂L/∂weights accumulated

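The dot product (re_embedding × ∂L/∂embedding).sum() is a mathematical surrogate. The cached gradient g_i = ∂L/∂e_i is a constant in Phase 3, so by the chain rule ∂L/∂θ = Σ_i g_iᵀ ∂e_i/∂θ = ∂/∂θ [Σ_i g_iᵀ e_i(θ)]. Differentiating the surrogate w.r.t. the model weights θ therefore produces exactly the gradient that running the full loss through the full 128-sample graph would, but only 4 samples' activations exist in memory at any moment.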

Memory footprint comparison:

                Model     Activation    Similarity
                weights   graph         matrix
                ───────   ──────────    ──────────
Naïve:          fixed     128 samples   128×128   ← OOM
GradCache Ph1:  fixed     4 samples     none      ← tiny
GradCache Ph2:  fixed     none          128×128   ← moderate
GradCache Ph3:  fixed     4 samples     none      ← tiny

Net result: exact gradients of a 128-sample training step, while never holding more than 4 samples' computation graph in memory at once. In the code: per_device_train_batch_size=128 (the logical batch), mini_batch_size=4 (the GradCache chunk for embedding phases only).
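The three phases can be verified end to end on a toy encoder. This is a minimal NumPy sketch, assuming a linear encoder e = W·x shared by queries and documents (so every gradient has a closed form) and in-batch InfoNCE without normalization; it runs GradCache with mini-batches of 2 and compares the accumulated weight gradient with a finite-difference estimate of the full-batch loss. All function names are hypothetical.

```python
import numpy as np

def loss_and_emb_grads(Q, D, scale):
    """Phase 2: in-batch InfoNCE on cached embeddings; returns (loss, dL/dQ, dL/dD)."""
    n = len(Q)
    logits = scale * (Q @ D.T)
    logits = logits - logits.max(axis=1, keepdims=True)     # softmax is shift-invariant
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)                        # row-wise softmax
    loss = -np.mean(np.log(np.diag(P)))                      # positives on the diagonal
    G = (P - np.eye(n)) / n                                  # dL/dlogits for mean cross-entropy
    return loss, scale * (G @ D), scale * (G.T @ Q)

def encode(W, X):
    return X @ W.T                                           # toy linear encoder: e = W x

def full_loss(W, Xq, Xd, scale):
    return loss_and_emb_grads(encode(W, Xq), encode(W, Xd), scale)[0]

def gradcache_weight_grad(W, Xq, Xd, scale, mini_batch=2):
    def chunks(X):
        return [X[i:i + mini_batch] for i in range(0, len(X), mini_batch)]
    # Phase 1: embed chunk by chunk, keeping only the output vectors (no graph is retained).
    Q = np.concatenate([encode(W, c) for c in chunks(Xq)])
    D = np.concatenate([encode(W, c) for c in chunks(Xd)])
    # Phase 2: full-batch loss on the cached embeddings -> one gradient per embedding.
    _, gQ, gD = loss_and_emb_grads(Q, D, scale)
    # Phase 3: re-embed chunk by chunk; with g held constant, d/dW sum(e * g) = g^T x for e = W x.
    grad_W = np.zeros_like(W)
    for X, g in ((Xq, gQ), (Xd, gD)):
        for x_c, g_c in zip(chunks(X), chunks(g)):
            grad_W += g_c.T @ x_c
    return grad_W

rng = np.random.default_rng(0)
W = 0.5 * rng.standard_normal((3, 5))                        # 5-dim inputs -> 3-dim embeddings
Xq, Xd = rng.standard_normal((6, 5)), rng.standard_normal((6, 5))
scale, eps = 2.0, 1e-6

# Reference: central finite differences of the full-batch loss, one weight at a time.
numeric = np.zeros_like(W)
for idx in np.ndindex(*W.shape):
    Wp, Wm = W.copy(), W.copy()
    Wp[idx] += eps
    Wm[idx] -= eps
    numeric[idx] = (full_loss(Wp, Xq, Xd, scale) - full_loss(Wm, Xq, Xd, scale)) / (2 * eps)

print(np.abs(gradcache_weight_grad(W, Xq, Xd, scale) - numeric).max())   # close to zero
```

The mini-batch size only changes how the work is chunked, not the result: any chunk size yields the same full-batch gradient, which is the property that lets `mini_batch_size=4` stand in for a logical batch of 128.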

4.3 RandContext: Keeping Phases 1 and 3 Identical

Phase 1 and Phase 3 both call model(mini_batch). If dropout uses different random masks between the two calls, the re-embedded vectors in Phase 3 differ from those cached in Phase 1, making the dot-product surrogate mathematically invalid.

RandContext fixes this by snapshotting the random number generator state (CPU plus the GPUs holding the inputs) right before each mini-batch's Phase 1 forward pass, then replaying it when that mini-batch is re-embedded in Phase 3:

rand_state = RandContext(*inputs)    # save RNG state
# Phase 1 runs with dropout mask A...
with rand_state:                     # rewind to saved state
    re_embed(...)                    # produces bit-for-bit identical output

One subtle point: the RandContext is created inside the grad context, immediately before the mini-batch's forward pass, so no other random draws can happen between the snapshot and the forward pass it describes. Replaying it in Phase 3 therefore reproduces exactly the dropout masks Phase 1 used.
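The snapshot-and-replay idea can be illustrated without a GPU. This is a minimal NumPy analogue, assuming a toy dropout layer driven by a `numpy.random.Generator`; `RandStateContext` and `dropout_forward` are hypothetical stand-ins. The sentence-transformers `RandContext` does the same with the PyTorch CPU and CUDA RNG states and, like this sketch, leaves the outer RNG stream untouched when the block exits.

```python
import numpy as np

class RandStateContext:
    """Capture a Generator's state on creation; restore it for the duration of a `with` block."""

    def __init__(self, rng: np.random.Generator):
        self.rng = rng
        self.saved_state = rng.bit_generator.state       # snapshot (a plain dict)

    def __enter__(self):
        self.outer_state = self.rng.bit_generator.state  # remember where the outer code was
        self.rng.bit_generator.state = self.saved_state  # rewind to the snapshot
        return self

    def __exit__(self, *exc):
        self.rng.bit_generator.state = self.outer_state  # resume the outer RNG stream
        return False

def dropout_forward(x: np.ndarray, rng: np.random.Generator, p: float = 0.1) -> np.ndarray:
    # Inverted dropout: zero each unit with probability p, rescale the survivors.
    keep = rng.random(x.shape) >= p
    return x * keep / (1 - p)

rng = np.random.default_rng(42)
x = np.ones((16, 64))

state = RandStateContext(rng)           # Phase 1: snapshot right before the forward pass
phase1 = dropout_forward(x, rng)
_ = rng.random(100)                     # unrelated random draws happen in between
with state:                             # Phase 3: replay the same masks
    phase3 = dropout_forward(x, rng)
unreplayed = dropout_forward(x, rng)    # without replay, a fresh mask is drawn
print(np.array_equal(phase1, phase3), np.array_equal(phase1, unreplayed))
```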

4.4 Self-Guide Filter: Removing False Negatives

In-batch negatives assume that any document which isn't the labeled positive is irrelevant. But in a batch of 128 document pages, there may be pages that also answer the query — they're labeled as negatives only because the dataset has one positive per query. Penalizing a model for ranking these highly sends the wrong training signal.

The Self-Guide Filter (from the Qwen3-Embedding paper) masks out suspiciously relevant negatives before the softmax:

# For each query qᵢ with positive dᵢ (raw cosine, stop-grad):
s_pos = raw_q2d[row_indices, local_batch]          # s(qᵢ, dᵢ), shape (N,)
threshold = s_pos - self_guide_margin               # margin -0.1 → s_pos + 0.1

# mask[i,j] = True iff s(qᵢ, dⱼ) > thresholdᵢ; unsqueeze so each row uses its own threshold
self_guide_mask = raw_q2d > threshold.unsqueeze(1)
self_guide_mask[row_indices, local_batch] = False  # never mask the true positive

# Apply after temperature scaling: masked cells contribute 0 to softmax
sim_matrices["query_to_doc"].masked_fill_(self_guide_mask, float("-inf"))

With self_guide_margin = -0.1 (the Qwen3 default), the threshold becomes s_pos + 0.1. A negative is masked only when it scores more than 0.1 above the positive — extremely conservative. Only the most obvious false negatives are removed. The threshold is computed on raw cosine similarities (before temperature scaling) to keep its interpretation consistent with the paper's formulation.
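The effect of the mask on the loss shows up clearly on a toy similarity matrix. This is a minimal NumPy sketch, assuming raw cosine scores are given directly and that column i holds the positive for query i (extra columns act like hard negatives); `self_guided_info_nce` is a hypothetical name, and the threshold follows the rule above (s_pos − margin, positive never masked).

```python
import numpy as np

def self_guided_info_nce(raw_q2d: np.ndarray, scale: float = 20.0, margin: float = -0.1):
    """raw_q2d: (N, M) raw cosine similarities, M >= N; column i is the positive for query i."""
    n = raw_q2d.shape[0]
    rows = np.arange(n)
    s_pos = raw_q2d[rows, rows]                          # s(q_i, d_i)
    threshold = s_pos - margin                           # margin -0.1 -> s_pos + 0.1
    mask = raw_q2d > threshold[:, None]                  # suspiciously relevant "negatives"
    mask[rows, rows] = False                             # never mask the true positive
    logits = np.where(mask, -np.inf, scale * raw_q2d)    # masked cells get exp(-inf) = 0
    m = logits.max(axis=1, keepdims=True)
    log_z = (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True))).ravel()
    return float(np.mean(-(logits[rows, rows] - log_z))), mask

raw = np.array([
    [0.60, 0.75, 0.20],   # query 0: doc 1 scores 0.15 above the positive -> masked
    [0.10, 0.70, 0.65],   # query 1: doc 2 is close, but not above s_pos + 0.1 -> kept
])
loss, mask = self_guided_info_nce(raw)
loss_unmasked, _ = self_guided_info_nce(raw, margin=-np.inf)   # threshold = +inf: nothing masked
print(mask.astype(int).tolist(), loss < loss_unmasked)         # [[0, 1, 0], [0, 0, 0]] True
```

Query 1's close competitor (0.65 vs 0.70) is deliberately kept: with a margin of -0.1, only candidates that beat the positive by a clear gap are treated as false negatives.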

4.5 Matryoshka Loss: Seven Embedding Sizes for the Price of One

A 4096-dim embedding is expressive but slow to search in production. You might want 256-dim for a latency-sensitive deployment. Retraining the model at each desired size is expensive.

Matryoshka Representation Learning (Kusupati et al. 2022) trains the model so that the first K dimensions alone form a useful embedding for any K in a predefined set. The first 128 dimensions encode the most critical information; subsequent dimensions encode progressively finer details.

Implementation: after computing full 4096-dim embeddings, slice and re-normalize at each target dimension and compute InfoNCE at each size:

loss = 0.0
for dim in [2048, 1536, 1024, 768, 512, 256, 128]:
    q = F.normalize(query_embeddings[:, :dim], dim=-1)  # slice the prefix, re-normalize
    d = F.normalize(doc_embeddings[:, :dim], dim=-1)
    loss += InfoNCE(q, d)

The gradient of F.normalize is (I − x̂x̂ᵀ) / ||x||, which flows back cleanly through the slice. Earlier prefix dimensions receive gradient contributions from all 7 loss terms; later dimensions receive fewer. The model is thus incentivized to pack the most discriminative information into the first coordinates.

With GradCache, the embeddings are computed once in Phase 1. CachedLossDecorator then runs the Matryoshka loop on the cached vectors — no re-running the model for each dimension. Matryoshka adds 7 cheap matrix operations, not 7 forward passes.

Ko-VDR's Matryoshka dimensions and weights:

dims    = [2048, 1536, 1024, 768, 512, 256, 128]
weights = [1.0,  1.0,  1.0, 1.0, 1.0, 1.0, 1.0]  # uniform
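The truncate, re-normalize, and sum loop can be sketched in NumPy. This is a minimal sketch, assuming in-batch InfoNCE only, uniform weights as in the configuration above, and a toy 32-dimensional embedding with dims 32/16/8/4 in place of the real sizes; `matryoshka_info_nce` is a hypothetical name. The embeddings are computed once, and only slicing and small matrix products repeat per size.

```python
import numpy as np

def info_nce(q, d, scale=20.0):
    logits = scale * (q @ d.T)
    m = logits.max(axis=1, keepdims=True)
    log_z = (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True))).ravel()
    return float(np.mean(log_z - np.diag(logits)))

def normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def matryoshka_info_nce(q_full, d_full, dims, weights=None):
    """Weighted sum of InfoNCE losses on prefix-truncated, re-normalized embeddings."""
    weights = weights or [1.0] * len(dims)
    per_dim = {}
    for dim, w in zip(dims, weights):
        q = normalize(q_full[:, :dim])      # keep the first `dim` coordinates, then re-normalize
        d = normalize(d_full[:, :dim])
        per_dim[dim] = w * info_nce(q, d)
    return sum(per_dim.values()), per_dim

rng = np.random.default_rng(0)
q_full = rng.standard_normal((8, 32))       # computed once (Phase 1 in the GradCache setting)
d_full = q_full + 0.5 * rng.standard_normal((8, 32))
total, per_dim = matryoshka_info_nce(q_full, d_full, dims=[32, 16, 8, 4])
print(round(total, 4), {k: round(v, 4) for k, v in per_dim.items()})
```

Because every term reads a prefix of the same vectors, coordinates in the smallest prefix feed all the loss terms, which is what pushes the most discriminative information toward the front.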

5. Training Configuration

5.1 Optimization Hyperparameters

Parameter                     Value
Epochs                        1
Learning rate                 2e-5
Batch size (per device)       128
GradCache mini-batch size     4
Warmup ratio                  0.1
Max grad norm                 1.0
Precision                     bfloat16
Gradient checkpointing        True
Eval frequency                Every 5% of total steps + on start
InfoNCE scale (1/τ)           20.0
Self-guide margin             -0.1

5.2 Batch Construction

Two custom batch samplers are used:

5.3 Multi-GPU Training with DDP

Training runs on 4 GPUs via torchrun with the NCCL backend. With gather_across_devices=False (our config), each GPU computes the loss on its own batch of 128 independently, so each query sees only the 127 in-batch negatives from its own GPU. The gradients are then averaged across GPUs via AllReduce, which makes each optimizer step equivalent to accumulating gradients over 4 batches of 128 on a single GPU: 512 queries per update, but no extra negatives per query. (With gather_across_devices=True, embeddings are shared via AllGather before loss computation, giving 511 negatives per query, at significant communication overhead; not used here.)

Several DDP defaults silently break with this setup and are hardcoded in VLMRetrieverTrainingArguments:

self.prediction_loss_only = True
self.ddp_broadcast_buffers = False
self.ddp_find_unused_parameters = True
self.dataloader_drop_last = True  # forced under DDP

Why each flag matters:

There is one additional override: the VLMRetrieverTrainer disables the auto-generated model card callback. BaseModelCardCallback.on_init_end() preprocesses 1000 samples per column to compute dataset statistics. With 8 image columns × 1000 samples = 8,000 image decodes per DDP rank, this exhausts ~25 GB of RAM per rank before the first training step even starts.

The training launch command:

torchrun --nproc_per_node=4 --master_port=29500 train.py \
  --model_name_or_path="Qwen/Qwen3-VL-Embedding-2B" \
  --train_data='{"vdr": "/data_x/EMBEDDING/DATA/VL/preprocessed/260419_ko_en_vdr"}' \
  --num_epochs=1 \
  --per_device_train_batch_size=128 \
  --learning_rate=2e-5 \
  --use_lora=True \
  --self_guide_margin=-0.1 \
  --output_dir="./outputs/vlm-retriever" \
  --use_wandb=True

6. Evaluation: KoViDoRe-v2

The model is evaluated on KoViDoRe-v2, a set of four BEIR-format Korean-English multimodal retrieval benchmarks, one per domain:

Each dataset provides three splits: corpus (document page images), queries (text queries), and qrels (relevance judgments as (query_id, corpus_id, score) triples). The primary metric is mean NDCG@10 across all four domains:

DCG@10  = Σ_{rank=1}^{10}  relevance(rank) / log2(rank + 1)
NDCG@10 = DCG@10 / IDCG@10    (IDCG@10 = DCG@10 of the ideal ranking, with all judged docs sorted by relevance)

Rank 1 contributes rel / log2(2) = rel; rank 2 contributes rel × 0.63, and so on. A SequentialEvaluator runs the four domain evaluators in sequence and reports the mean NDCG@10 as the model selection metric. Evaluation runs every 5% of training steps (plus at initialization) with eval batch size 8. Corpus images are materialized to disk once and lazily decoded per batch to avoid OOM.
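The metric itself takes a few lines of plain Python. This is a minimal sketch, assuming graded relevance read from a qrels-style dictionary and the linear-gain DCG formula above (some libraries use a 2^rel − 1 gain instead, which is identical for binary judgments); `ndcg_at_k` is a hypothetical helper, not the sentence-transformers evaluator.

```python
import math

def ndcg_at_k(ranked_doc_ids: list[str], qrels: dict[str, int], k: int = 10) -> float:
    """NDCG@k for one query.

    ranked_doc_ids: retrieved corpus ids, best first; qrels: corpus_id -> relevance for this query.
    """
    dcg = sum(
        qrels.get(doc_id, 0) / math.log2(rank + 1)
        for rank, doc_id in enumerate(ranked_doc_ids[:k], start=1)
    )
    ideal = sorted(qrels.values(), reverse=True)[:k]          # best possible ordering
    idcg = sum(rel / math.log2(rank + 1) for rank, rel in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0

# One relevant page retrieved at rank 2: NDCG@10 = (1 / log2(3)) / 1 ≈ 0.63.
print(round(ndcg_at_k(["p7", "p3", "p9"], {"p3": 1}), 4))
# A domain's score averages over its queries; the selection metric then averages the four domains.
runs = {"q1": (["p3", "p7"], {"p3": 1}), "q2": (["p7", "p3"], {"p3": 1})}
print(round(sum(ndcg_at_k(r, q) for r, q in runs.values()) / len(runs), 4))
```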

7. The Complete Forward-Backward Flow

Putting it all together, one training step looks like this:

Batch of 128 (text query, image path) pairs
     │
     ▼  BaseDataCollator.__call__
  ① extract dataset_name; ② warn on column order; ③ extract labels
  ④ resolve prompts (anchor: "Represent the user's input." / positive: "")
  ⑤ preprocess each column → prefixed tensors:
     anchor_input_ids:        (128, L_text)
     positive_pixel_values:   (Σ_patches, 4096)   ← all 128 images concatenated
     positive_image_grid_thw: (128, 3)              ← (T, H, W) per image
     negative_0_pixel_values: (Σ_patches, 4096)
     ...
     │
     ▼  collect_features → [{anchor_feats}, {positive_feats}, {neg_0_feats}, ...]
     │
     ▼  MatryoshkaLoss.forward
  ├─ installs CachedLossDecorator on CachedMNRL.calculate_loss
  └─► CachedMNRL.forward(sentence_features, labels)
       │
       ├─ PHASE 1 [no_grad, mini_batch=4]
       │    RandContext(*inputs) → save CPU+GPU RNG state
       │    model(mini_batch) → (4, 4096)  [no computation graph]
       │    .detach().requires_grad_() → artificial leaf
       │
       ├─ PHASE 2 [calculate_loss_and_cache_gradients]
       │    CachedLossDecorator loops over dims [2048, 1536, 1024, 768, 512, 256, 128]:
       │      shrink(full_embed, dim): [:dim] + F.normalize
       │      Self-Guide mask: raw_q2d.detach() > (s_pos + 0.1) → fill(-inf)
       │      sim_matrix * scale(20.0)
       │      positive_scores = sim_matrix[row_indices, local_batch]
       │      log_z = logsumexp(all_scores)
       │      per_sample_loss = -(positive_scores - log_z)
       │      .backward() → matryoshka leaf .grad filled
       │      chain back through shrink into full_embed.grad
       │    self.cache = full_embed.grad  (∂L/∂embedding)
       │
       └─ PHASE 3 [_backward_hook, fires on loss.backward()]
            torch.enable_grad()
            for each column, each mini-batch:
              replay RandContext → model(mini_batch) WITH grad, same dropout masks
              surrogate = dot(reps_mb, cached_grad) * grad_output
              surrogate.backward()
              → weight.grad = ∂L/∂weight  ✓
     │
     ▼  AdamW.step
  Update ~30M LoRA A & B params
  Skip ~1.97B frozen params

8. Why Each Technique Matters

Every design decision is solving either "how do we make the model better?" or "how do we make training feasible on real hardware?" Most do both at once.

Technique               Effectiveness                                               Efficiency
Large batch (128/GPU)   127 negatives per query → harder task → better model        Slower per step
GradCache               Enables the large batch that improves quality               Never holds full computation graph in memory
RandContext             Correct gradient (wrong gradient = wasted training steps)   Near-zero cost: just saves/restores RNG state
DDP (4 GPUs)            512 queries per optimizer step (still 127 negatives each)   ~4× throughput
Matryoshka (7 dims)     One model supports 7 speed/accuracy tradeoffs               7 loss computations, no extra forward passes
Self-Guide filter       Cleaner signal by removing false negatives                  Tiny overhead for mask computation
LoRA (r=32)             Targets layers where domain adaptation matters most         ~240 MB optimizer state vs ~16 GB full fine-tune
Hard negative mining    Harder negatives → finer-grained discrimination             Pre-computed offline with the larger 8B model
No-duplicate sampler    Every batch slot is a genuinely distinct negative           Greedy O(N) fill, negligible overhead
The fundamental tension in contrastive retrieval training: more negatives → better model, but more negatives → more GPU memory → OOM. Every major technique here is a way to escape that tension. GradCache decouples batch size from memory; DDP multiplies the number of queries per optimizer step; LoRA frees memory headroom for larger batches; CachedLossDecorator keeps Matryoshka from requiring 7× more forward passes. The result: each query trains against 127 in-batch negatives (plus the mined hard negatives in its batch), each optimizer step averages gradients over 512 queries, and all of it runs on hardware that couldn't hold even 16 full-resolution document images without these tricks.

9. Takeaways