From 22b90066b26725187e8d8a83e1d67736d12a59d7 Mon Sep 17 00:00:00 2001 From: Kevin Witlox Date: Tue, 23 Dec 2025 10:56:34 +0100 Subject: [PATCH] feat: migrate to nvim-treesitter 'main' --- lua/python/treesitter/init.lua | 3 +- lua/python/treesitter/nodes.lua | 50 +++++++++++++++--------------- scripts/minimal_init.lua | 55 +++++++-------------------------- 3 files changed, 38 insertions(+), 70 deletions(-) diff --git a/lua/python/treesitter/init.lua b/lua/python/treesitter/init.lua index b406475..8f3d6b9 100644 --- a/lua/python/treesitter/init.lua +++ b/lua/python/treesitter/init.lua @@ -1,4 +1,3 @@ -local tsutil = require("nvim-treesitter.ts_utils") local nodes = require("python.treesitter.nodes") local PythonTreeSitter = { @@ -8,7 +7,7 @@ local PythonTreeSitter = { } function PythonTreeSitter.test_ts_queries() - local current_node = tsutil.get_node_at_cursor() + local current_node = vim.treesitter.get_node() if not current_node then return end diff --git a/lua/python/treesitter/nodes.lua b/lua/python/treesitter/nodes.lua index c1315a4..2609bb0 100644 --- a/lua/python/treesitter/nodes.lua +++ b/lua/python/treesitter/nodes.lua @@ -1,17 +1,25 @@ local PythonTreeSitterNodes = {} -- part of the code from polarmutex/contextprint.nvim -local ts_utils = require("nvim-treesitter.ts_utils") -local ts_query = require("nvim-treesitter.query") -local parsers = require("nvim-treesitter.parsers") -local locals = require("nvim-treesitter.locals") --- local vim_query = require("vim.treesitter.query") local api = vim.api local fn = vim.fn local get_node_text = vim.treesitter.get_node_text local parse = vim.treesitter.query.parse -if parse == nil then - parse = vim.treesitter.query.parse_query + +-- Helper function to convert 0-indexed treesitter range to 1-indexed vim range +local function get_vim_range(node) + local start_row, start_col, end_row, end_col = node:range() + return start_row + 1, start_col + 1, end_row + 1, end_col + 1 +end + +-- Helper function to recurse through match captures (replaces locals.recurse_local_nodes) +local function recurse_captures(match, query, callback) + for id, node in pairs(match) do + local name = query.captures[id] + if name then + callback(id, node, name) + end + end end PythonTreeSitterNodes.count_parents = function(node) @@ -54,21 +62,19 @@ PythonTreeSitterNodes.get_nodes = function(query, lang, defaults, bufnr) return nil end - local parser = parsers.get_parser(bufnr, lang) + local parser = vim.treesitter.get_parser(bufnr, lang) local root = parser:parse()[1]:root() local start_row, _, end_row, _ = root:range() local results = {} - for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do + for pattern, match, metadata in parsed_query:iter_matches(root, bufnr, start_row, end_row) do local sRow, sCol, eRow, eCol local declaration_node local type = "nil" local name = "nil" - locals.recurse_local_nodes(match, function(_, node, path) + recurse_captures(match, parsed_query, function(_, node, path) local idx = string.find(path, ".", 1, true) local op = string.sub(path, idx + 1, #path) - -- local a1, b1, c1, d1 = vim.treesitter.get_node_range(node) - type = string.sub(path, 1, idx - 1) if name == nil then name = defaults[type] or "empty" @@ -78,11 +84,7 @@ PythonTreeSitterNodes.get_nodes = function(query, lang, defaults, bufnr) name = get_node_text(node, bufnr) elseif op == "declaration" then declaration_node = node - sRow, sCol, eRow, eCol = node:range() - sRow = sRow + 1 - eRow = eRow + 1 - sCol = sCol + 1 - eCol = eCol + 1 + sRow, sCol, eRow, eCol = get_vim_range(node) end end) @@ -124,12 +126,12 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos return nil end - local parser = parsers.get_parser(bufnr, lang) + local parser = vim.treesitter.get_parser(bufnr, lang) local root = parser:parse()[1]:root() local start_row, _, end_row, _ = root:range() local results = {} local node_type - for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do + for pattern, match, metadata in parsed_query:iter_matches(root, bufnr, start_row, end_row) do local sRow, sCol, eRow, eCol local declaration_node local type_node @@ -139,14 +141,13 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos -- local method_receiver = "" -- ulog(match) - locals.recurse_local_nodes(match, function(_, node, path) + recurse_captures(match, parsed_query, function(_, node, path) -- local idx = string.find(path, ".", 1, true) -- The query may return multiple nodes, e.g. -- (type_declaration (type_spec name:(type_identifier)@type_decl.name type:(type_identifier)@type_decl.type))@type_decl.declaration -- returns { { @type_decl.name, @type_decl.type, @type_decl.declaration} ... } local idx = string.find(path, ".[^.]*$") -- find last `.` op = string.sub(path, idx + 1, #path) - local a1, b1, c1, d1 = vim.treesitter.get_node_range(node) local dbg_txt = get_node_text(node, bufnr) or "" if #dbg_txt > 100 then dbg_txt = string.sub(dbg_txt, 1, 100) .. "..." @@ -172,8 +173,7 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos type_node = node elseif op == 'declaration' or op == 'clause' then declaration_node = node - sRow, sCol, eRow, eCol = - ts_utils.get_vim_range({ vim.treesitter.get_node_range(node) }, bufnr) + sRow, sCol, eRow, eCol = get_vim_range(node) else -- ulog('unknown op: ' .. op) end @@ -194,7 +194,7 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos end if type_node ~= nil and ntype then -- ulog('type_only') - sRow, sCol, eRow, eCol = ts_utils.get_vim_range({ vim.treesitter.get_node_range(type_node) }, bufnr) + sRow, sCol, eRow, eCol = get_vim_range(type_node) table.insert(results, { type_node = type_node, dim = { s = { r = sRow, c = sCol }, e = { r = eRow, c = eCol } }, @@ -266,7 +266,7 @@ PythonTreeSitterNodes.nodes_at_cursor = function(query, default, bufnr, ntype) end function PythonTreeSitterNodes.inside_function() - local current_node = ts_utils.get_node_at_cursor() + local current_node = vim.treesitter.get_node() if not current_node then return false end diff --git a/scripts/minimal_init.lua b/scripts/minimal_init.lua index ce9bf1b..c10d15a 100644 --- a/scripts/minimal_init.lua +++ b/scripts/minimal_init.lua @@ -18,51 +18,20 @@ vim.cmd("set rtp+=" .. runtime_path) require("luasnip").setup() require("luasnip.extras.fmt") require("luasnip.nodes.absolute_indexer") -require("nvim-treesitter.locals") -require("nvim-treesitter").setup() require("mini.test").setup() require("mini.doc").setup() -require("nvim-treesitter.configs").setup({ - modules = { - "highlight", - }, - sync_install = false, - auto_install = true, - ignore_install = {}, - ensure_installed = {}, - highlight = { - enable = true, - }, -}) --- Clean path for use in a prefix comparison ----@param input string ----@return string -local function clean_path(input) - local pth = vim.fn.fnamemodify(input, ":p") - if vim.fn.has("win32") == 1 then - pth = pth:gsub("/", "\\") - end - return pth -end +-- Setup nvim-treesitter +require("nvim-treesitter").install("python") -local function ts_is_installed(lang) - local matched_parsers = vim.api.nvim_get_runtime_file("parser/" .. lang .. ".so", true) or {} - local configs = require("nvim-treesitter.configs") - local install_dir = configs.get_parser_install_dir() - if not install_dir then - return false - end - install_dir = clean_path(install_dir) - for _, path in ipairs(matched_parsers) do - local abspath = clean_path(path) - if vim.startswith(abspath, install_dir) then - return true - end - end - return false -end +vim.api.nvim_create_autocmd("FileType", { + group = vim.api.nvim_create_augroup("PythonTreesitter", { clear = true }), + pattern = "python", + desc = "Enable treesitter highlighting and indentation", + callback = function(event) + local buf = event.buf -if not ts_is_installed("python") then - vim.cmd("TSInstallSync python") -end + -- Start highlighting + pcall(vim.treesitter.start, buf, "python") + end, +})