-
Notifications
You must be signed in to change notification settings - Fork 123
Add Nomic BERT model #440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add Nomic BERT model #440
Conversation
Add support for nomic-ai/nomic-embed-text-v1.5 embedding model. Architecture: - Postnorm transformer (standard BERT-style) - SwiGLU activation (up * silu(gate)) - Rotary position embeddings (RoPE) with base 1000 - Combined Wqkv projection - No biases in attention and FFN layers - Mean pooling over non-masked tokens Tested against Python transformers with ~2e-6 precision.
462663b to
513b16c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds support for the Nomic BERT model family, specifically the nomic-embed-text-v1.5 embedding model. This is a BERT variant that incorporates modern architectural improvements including Rotary Position Embeddings (RoPE), SwiGLU activation functions, and bias-free attention/FFN layers.
Key Changes:
- Implementation of Nomic BERT architecture with combined Wqkv projection and gated FFN
- Mean pooling for generating sentence embeddings
- HuggingFace model loading support with parameter conversion from combined Wqkv weights
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| lib/bumblebee/text/nomic_bert.ex | Complete model implementation including embedder, encoder with postnorm transformer blocks, gated FFN (SwiGLU), and HuggingFace config/parameter mappings |
| test/bumblebee/text/nomic_bert_test.exs | Integration test that loads nomic-embed-text-v1.5 from HuggingFace and verifies output shapes and values against Python transformers |
| lib/bumblebee.ex | Registers NomicBertModel architecture and maps nomic_bert model type to BERT tokenizer |
| mix.exs | Adds Bumblebee.Text.NomicBert to documentation models list |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
lib/bumblebee/text/nomic_bert.ex
Outdated
| doc: "the activation function" | ||
| ], | ||
| rotary_embedding_base: [ | ||
| default: 10_000, |
Copilot
AI
Jan 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The description states that Nomic BERT uses "Rotary position embeddings (base 1000)", and the test verifies that rotary_embedding_base is 1000 for the loaded model. However, the default value here is set to 10_000. This mismatch could cause issues when users try to create models from scratch without loading from HuggingFace. The default should be 1000 to match the actual Nomic BERT specification.
| default: 10_000, | |
| default: 1000, |
- Fix max_positions doc: remove incorrect "vocabulary size" reference - Fix rotary_embedding_base default: change from 10_000 to 1000 - Fix normalization doc: correct "pre-normalization" to "post-normalization" - Fix position_ids doc: clarify usage with RoPE instead of position embeddings
jonatanklosko
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NomicBERT is not implemented in Python transformers directly, instead the source code lives on HuggingFace Hub and that model implementation is used on the fly. Normally we don't add such models, because the source is more likely to change and sometimes there is more than one source code.
That said, it seems that all nomic models point to the reference implementation from nomic-ai/nomic-bert-2048 and it's been a while since the last changes. So in this case I'd say it's ok to add it.
I added comments inline.
| assert_all_close( | ||
| outputs.hidden_state[[.., 0, 0..4]], | ||
| Nx.tensor([[1.3752, 0.7431, -4.6988, -0.6574, 2.1887]]), | ||
| atol: 1.0e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not need to override :atol here, if it fails with the default 1.0e-4, it most likely means the model implementation does not match the reference one, so that needs to be addressed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's worth having a look at the config attributes. For example
mlp_fc1_bias and mlp_fc2_bias seem relevant, but we don't import those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, i was disabling bias for every model. Now I'm importing the appropriate config.
lib/bumblebee/text/nomic_bert.ex
Outdated
| # Nomic BERT uses postnorm (like standard BERT): | ||
| # Each block: | ||
| # attn_output = attention(hidden_states) | ||
| # hidden_states = norm1(attn_output + hidden_states) | ||
| # ffn_output = ffn(hidden_states) | ||
| # hidden_states = norm2(ffn_output + hidden_states) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems standard, is there any reason we cannot use Layers.Transformer.blocks, as we do for most other models?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I initially thought the model implementation should be self-contained, but honestly I just didn't pay close attention to what Layers.Transformer.blocks already provides. Thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I initially thought the model implementation should be self-contained
That's what hf/transformers do, but it has tradeoffs, and in our case we ended up normalizing and sharing all of the core transformer logic. This makes it easier to maintain and add new models, because most LLMs introduce just a few differences.
Round intermediate_size to multiple of 256 to match Python's GatedMLP behavior.
|
using tiny model now |
I needed nomic-embed-text-v1.5 for a project, so I added support for it.
It's a BERT variant with a few differences:
I tested it against Python transformers and the outputs match within floating point precision (~2e-6).