feat: Lazy Spans and KV Blocks #249
Status: Open

nrfulton wants to merge 57 commits into generative-computing:main from nrfulton:nathan/conceptual_spans
Changes from all 57 commits:
- 49fdcf1  Adds cache smash code from the Project M codebase. (nrfulton)
- a1a4eb7  rename to avoid clash b/w cache/ and cache.py (nrfulton)
- 5989664  Adds cache flag to CBlock. (nrfulton)
- fab35d9  Initial work on re-introducing span-ish KV caching. (nrfulton)
- a648405  Adds a crystallization of the kv smash code (nrfulton)
- ead3fe8  Adds KV cache smash. (nrfulton)
- 1cd08ae  Adds example of kv cache smash. (nrfulton)
- 212a768  Merge branch 'main' into nathan/kv_block_hack (avinash2692)
- e53e13b  Merge branch 'generative-computing:main' into nathan/kv_block_hack (nrfulton)
- 806eef7  Merge branch 'main' into nathan/kv_block_hack (nrfulton)
- 60f2178  Adds a SimpleComponent. (nrfulton)
- ef6daf6  Adds a simple lazy example. (nrfulton)
- 2fa6967  ollama generate walk. (nrfulton)
- a9be6f6  Does gather() instead of awaiting on each thunk separately. (nrfulton)
- 5ea5312  Refactor and bug fixes. (nrfulton)
- 22ac0db  backend walks. (nrfulton)
- 7187941  Adds heapcomponents. (nrfulton)
- 6ea6d46  Make uncomputed mots logging less noisy. (nrfulton)
- 4f37d96  adds a simple example. (nrfulton)
- 152ede9  Cleans up fib example. (nrfulton)
- 477275d  Adds parts() for instruction and genslot components. (nrfulton)
- 976ac06  Don't call things components which are not components. (nrfulton)
- 1797be4  ruff. (nrfulton)
- 24de761  Starts adding some examples for a deepdive on sessions. (nrfulton)
- ea3e789  blah (nrfulton)
- b10ba6d  blah (nrfulton)
- 2b07ae4  Add parts() to chat. (nrfulton)
- 1080f3e  Merge branch 'main' into nathan/conceptual_spans (nrfulton)
- e8be711  Fixes GenerativeSlot.parts() (nrfulton)
- a75fd4c  Confirm assumption that RichDocument has no parts() for now. (nrfulton)
- 6230138  Define parts() on TableQuery (nrfulton)
- 1676956  Fixes ruff errors. (nrfulton)
- 18707e4  Merge branch 'main' of ssh://github.com/generative-computing/mellea i… (nrfulton)
- 009328c  Fixes error in HeapContext.add caught by mypy. (nrfulton)
- 50b803a  Fixes mypy errors caused by shadowing (nrfulton)
- ecceed7  Adds parts() definitions to the rest of the RichDocument components. (nrfulton)
- eb8c557  fixes Instruction.parts() (nrfulton)
- 82852eb  Improves warning message for Intrinsic.parts() (nrfulton)
- 62b4a96  update comment on mify.parts() (nrfulton)
- c087f12  parts() implementations for MObject components. (nrfulton)
- 6a143f2  parts() implementation for Requirements. (nrfulton)
- 509eb10  Some notes about the deep dives. (nrfulton)
- eae8aca  Fixes line noise in previous commit. (nrfulton)
- 7ec3aef  PARTIAL Merge branch 'nathan/kv_block_hack' into nathan/conceptual_spans (nrfulton)
- 6095189  Finish resolving merge. (nrfulton)
- 912d6ea  Examples are working (for some value of working -- results are garbage). (nrfulton)
- 156f3db  precommit hooks are passing. (nrfulton)
- 09d502c  Small changes to hf kv smash example. (nrfulton)
- 7f21472  Fix fib example. (nrfulton)
- ece648a  Remove accidental commit. (nrfulton)
- a81818d  Removes unnecessary print statements. (nrfulton)
- 1c5a03a  Removes HeapContext. (nrfulton)
- 0e962b4  Intrinsics cannot surface parts because they always rewrite history (nrfulton)
- f247196  removes dead helper code. (nrfulton)
- adea6aa  removed code clone. (nrfulton)
- 5d9a2c1  adds test. (nrfulton)
- cce4fd2  Adds type:ignore because mypy 1.19.1 is buggy. (nrfulton)
New file (@@ -0,0 +1,39 @@):

```python
import asyncio
from mellea.stdlib.base import (
    SimpleContext,
    Context,
    CBlock,
    ModelOutputThunk,
    SimpleComponent,
)
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend

backend = OllamaModelBackend("granite4:latest")


async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
    sc = SimpleComponent(
        instruction="What is x+y? Respond with the number only.", x=x, y=y
    )
    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
    return mot


async def main(backend: Backend, ctx: Context):
    fibs = []
    for i in range(100):
        if i == 0 or i == 1:
            fibs.append(CBlock(f"{i}"))
        else:
            fibs.append(await fib(backend, ctx, fibs[i - 1], fibs[i - 2]))

    for x in fibs:
        match x:
            case ModelOutputThunk():
                print(await x.avalue())
            case CBlock():
                print(x.value)


asyncio.run(main(backend, SimpleContext()))
```
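The heart of the example above is a thunk: an object that stands for a value which may not have been computed yet, pays for the computation on first await, and caches the result. A minimal plain-asyncio sketch of that pattern follows; all names here (`LazyThunk`, `demo`) are hypothetical illustrations, not mellea's actual `ModelOutputThunk` implementation:

```python
import asyncio
from typing import Awaitable, Callable


class LazyThunk:
    """Defers an async computation until first awaited; caches the result.

    Illustrative sketch only; not mellea's ModelOutputThunk.
    """

    def __init__(self, compute: Callable[[], Awaitable[str]]):
        self._compute = compute
        self._value: str | None = None

    def is_computed(self) -> bool:
        return self._value is not None

    async def avalue(self) -> str:
        # Compute on first access only; later awaits return the cached value.
        if self._value is None:
            self._value = await self._compute()
        return self._value


async def demo() -> str:
    calls = 0

    async def expensive() -> str:
        nonlocal calls
        calls += 1  # stands in for a model call
        return "42"

    t = LazyThunk(expensive)
    assert not t.is_computed()
    v1 = await t.avalue()
    v2 = await t.avalue()  # cached: expensive() ran only once
    assert v1 == v2 == "42" and calls == 1
    return v1


print(asyncio.run(demo()))  # prints 42
```

The design choice worth noting is that construction is cheap and side-effect free; only `avalue()` forces work, which is what makes batching strategies like `gather()` over many uncomputed thunks possible.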
New file (@@ -0,0 +1,44 @@):

```python
import asyncio
from mellea.stdlib.base import (
    SimpleContext,
    Context,
    CBlock,
    ModelOutputThunk,
    SimpleComponent,
)
from mellea.stdlib.requirement import Requirement
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend
from typing import Tuple

backend = OllamaModelBackend("granite4:latest")


async def fib(backend: Backend, ctx: Context, x: CBlock, y: CBlock) -> ModelOutputThunk:
    sc = SimpleComponent(
        instruction="What is x+y? Respond with the number only.", x=x, y=y
    )
    mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())
    return mot


async def fib_main(backend: Backend, ctx: Context):
    fibs = []
    for i in range(20):
        if i == 0 or i == 1:
            fibs.append(CBlock(f"{i}"))
        else:
            mot = await fib(backend, ctx, fibs[i - 1], fibs[i - 2])
            fibs.append(mot)

    print(await fibs[-1].avalue())
    # for x in fibs:
    #     match x:
    #         case ModelOutputThunk():
    #             n = await x.avalue()
    #             print(n)
    #         case CBlock():
    #             print(x.value)


asyncio.run(fib_main(backend, SimpleContext()))
```
New file (@@ -0,0 +1,66 @@):

```python
import asyncio
from mellea.stdlib.base import (
    SimpleContext,
    Context,
    CBlock,
    ModelOutputThunk,
    SimpleComponent,
)
from mellea.stdlib.requirement import Requirement
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend
from typing import Tuple

backend = OllamaModelBackend("granite4:latest")


async def _fib_sample(
    backend: Backend, ctx: Context, x: CBlock, y: CBlock
) -> ModelOutputThunk | None:
    sc = SimpleComponent(
        instruction="What is x+y? Respond with the number only.", x=x, y=y
    )
    answer_mot, _ = await backend.generate_from_context(action=sc, ctx=SimpleContext())

    # This is a fundamental thing: it means computation must occur.
    # We need to be able to read this off at c.g. construction time.
    value = await answer_mot.avalue()

    try:
        int(value)
        return answer_mot
    except ValueError:
        return None


async def fib_sampling_version(
    backend: Backend, ctx: Context, x: CBlock, y: CBlock
) -> ModelOutputThunk | None:
    for i in range(5):
        sample = await _fib_sample(backend, ctx, x, y)
        if sample is not None:
            return sample
    return None


async def fib_sampling_version_main(backend: Backend, ctx: Context):
    fibs = []
    for i in range(20):
        if i == 0 or i == 1:
            fibs.append(CBlock(f"{i}"))
        else:
            mot = await fib_sampling_version(backend, ctx, fibs[i - 1], fibs[i - 2])
            fibs.append(mot)

    for x_i, x in enumerate(fibs):
        match x:
            case ModelOutputThunk():
                n = await x.avalue()
                print(n)
            case CBlock():
                print(x.value)


asyncio.run(fib_sampling_version_main(backend, SimpleContext()))
```
New file (@@ -0,0 +1,38 @@):

```python
import asyncio
from mellea.stdlib.base import Context, CBlock, SimpleContext, ModelOutputThunk
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend


async def main(backend: Backend, ctx: Context):
    """In this example, we show how executing multiple MOTs in parallel should work."""
    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"

    poem_thunks = []
    for state_name in m_states:
        mot, ctx = await backend.generate_from_context(
            CBlock(f"Write a poem about {state_name}"), ctx
        )
        poem_thunks.append(mot)

    # Notice that what we have now is a list of ModelOutputThunks, none of which are computed.
    for poem_thunk in poem_thunks:
        assert type(poem_thunk) == ModelOutputThunk
        print(f"Computed: {poem_thunk.is_computed()}")

    # Let's run all of these in parallel.
    await asyncio.gather(*[c.avalue() for c in poem_thunks])

    # Print out the final results, which are now computed.
    for poem_thunk in poem_thunks:
        print(f"Computed: {poem_thunk.is_computed()}")

    # And let's print out the final results.
    for poem_thunk in poem_thunks:
        print(poem_thunk.value)


backend = OllamaModelBackend(model_id="granite4:latest")
asyncio.run(main(backend, SimpleContext()))
```
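The fan-out pattern in that file is general: build all the deferred computations first without awaiting any of them, then resolve the whole batch with one `asyncio.gather`. The same shape can be shown in plain asyncio with no backend at all; `fake_generate` below is a hypothetical stand-in for a model call, not part of mellea:

```python
import asyncio


async def fake_generate(prompt: str) -> str:
    # Hypothetical stand-in for a backend call.
    await asyncio.sleep(0.01)
    return f"poem about {prompt}"


async def main() -> list[str]:
    states = ["Missouri", "Minnesota", "Montana", "Massachusetts"]
    # Build all coroutines first; none of them runs until gather awaits them,
    # and gather runs them concurrently while preserving input order.
    results = await asyncio.gather(*[fake_generate(s) for s in states])
    return list(results)


print(asyncio.run(main()))
```

Because `gather` preserves order, the i-th result always corresponds to the i-th state, which is what lets the real example zip results back to their thunks safely.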
New file (@@ -0,0 +1,44 @@):

```python
from mellea.stdlib.base import SimpleContext, Context, CBlock, SimpleComponent
from mellea.backends import Backend
from mellea.backends.ollama import OllamaModelBackend
import asyncio


async def main(backend: Backend, ctx: Context):
    a_states = "Alaska,Arizona,Arkansas".split(",")
    m_states = "Missouri", "Minnesota", "Montana", "Massachusetts"

    a_state_pops = dict()
    for state in a_states:
        a_state_pops[state], _ = await backend.generate_from_context(
            CBlock(f"What is the population of {state}? Respond with an integer only."),
            SimpleContext(),
        )
    a_total_pop = SimpleComponent(
        instruction=CBlock(
            "What is the total population of these states? Respond with an integer only."
        ),
        **a_state_pops,
    )
    a_state_total, _ = await backend.generate_from_context(a_total_pop, SimpleContext())

    m_state_pops = dict()
    for state in m_states:
        m_state_pops[state], _ = await backend.generate_from_context(
            CBlock(f"What is the population of {state}? Respond with an integer only."),
            SimpleContext(),
        )
    m_total_pop = SimpleComponent(
        instruction=CBlock(
            "What is the total population of these states? Respond with an integer only."
        ),
        **m_state_pops,
    )
    m_state_total, _ = await backend.generate_from_context(m_total_pop, SimpleContext())

    print(await a_state_total.avalue())
    print(await m_state_total.avalue())


backend = OllamaModelBackend(model_id="granite4:latest")
asyncio.run(main(backend, SimpleContext()))
```
New file (@@ -0,0 +1,53 @@):

```python
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock, ChatContext
from mellea.stdlib.chat import Message
import asyncio


async def example():
    ctx = ChatContext(window_size=100)
    ctx = ctx.add(
        CBlock(
            "Nathan Fulton is a Senior Research Scientist at the MIT-IBM Watson AI Lab, a joint venture between MIT and IBM.",
            cache=True,
        )
    )
    ctx = ctx.add(
        CBlock(
            "The MIT-IBM Watson AI Lab is located at 314 Main St, Cambridge, Massachusetts.",
            cache=True,
        )
    )
    ctx = ctx.add(
        CBlock("The ZIP code for 314 Main St, Cambridge, Massachusetts is 02142")
    )

    msg = Message(
        role="user",
        content="What is the likely ZIP code of Nathan Fulton's work address?",
    )
    backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
    mot = await backend._generate_from_context_with_kv_cache(
        action=msg, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 64}
    )
    # mot = await backend._generate_from_context_standard(
    #     action=msg, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 10}
    # )

    result = await mot.avalue()
    print(f".{result}.")

    msg2 = Message(
        role="user",
        content="We know that Nathan does not work for a university. What is the likely name of Nathan's employer?",
    )
    mot = await backend._generate_from_context_with_kv_cache(
        action=msg2, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 64}
    )
    result = await mot.avalue()
    print(f".{result}.")


asyncio.run(example())
```
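The `cache=True` flag on those context blocks marks spans whose KV state should be computed once and reused across generations that share the same prefix. The payoff can be sketched with a toy prefix cache; this is purely illustrative (a real KV cache stores per-layer attention key/value tensors, not strings, and `PrefixCache` is a hypothetical name, not part of mellea):

```python
class PrefixCache:
    """Toy model of KV prefix reuse: the expensive work for a cached
    prefix happens once, and later requests sharing it pay nothing."""

    def __init__(self):
        self._store: dict[str, int] = {}  # prefix -> "precomputed state"
        self.computations = 0

    def _compute(self, prefix: str) -> int:
        # Stands in for running the forward pass over the prefix tokens.
        self.computations += 1
        return len(prefix)

    def state_for(self, prefix: str) -> int:
        if prefix not in self._store:
            self._store[prefix] = self._compute(prefix)
        return self._store[prefix]


cache = PrefixCache()
shared = "Nathan Fulton is a Senior Research Scientist at the MIT-IBM Watson AI Lab."
cache.state_for(shared)  # first generation pays for the prefix
cache.state_for(shared)  # second generation reuses it for free
assert cache.computations == 1
```

The design question the PR's cache flag answers is *which* blocks are worth pinning: only blocks that recur verbatim across generations (here, the two biography facts) benefit, while the final turn-specific message does not.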