Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lua/python/treesitter/init.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
local tsutil = require("nvim-treesitter.ts_utils")
local nodes = require("python.treesitter.nodes")

local PythonTreeSitter = {
Expand All @@ -8,7 +7,7 @@ local PythonTreeSitter = {
}

function PythonTreeSitter.test_ts_queries()
local current_node = tsutil.get_node_at_cursor()
local current_node = vim.treesitter.get_node()
if not current_node then
return
end
Expand Down
50 changes: 25 additions & 25 deletions lua/python/treesitter/nodes.lua
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
local PythonTreeSitterNodes = {}

-- part of the code from polarmutex/contextprint.nvim
local ts_utils = require("nvim-treesitter.ts_utils")
local ts_query = require("nvim-treesitter.query")
local parsers = require("nvim-treesitter.parsers")
local locals = require("nvim-treesitter.locals")
-- local vim_query = require("vim.treesitter.query")
local api = vim.api
local fn = vim.fn
local get_node_text = vim.treesitter.get_node_text
local parse = vim.treesitter.query.parse
if parse == nil then
parse = vim.treesitter.query.parse_query

-- Helper function to convert 0-indexed treesitter range to 1-indexed vim range
local function get_vim_range(node)
local start_row, start_col, end_row, end_col = node:range()
return start_row + 1, start_col + 1, end_row + 1, end_col + 1
end

-- Helper function to recurse through match captures (replaces locals.recurse_local_nodes)
local function recurse_captures(match, query, callback)
for id, node in pairs(match) do
local name = query.captures[id]
if name then
callback(id, node, name)
end
end
end

PythonTreeSitterNodes.count_parents = function(node)
Expand Down Expand Up @@ -54,21 +62,19 @@ PythonTreeSitterNodes.get_nodes = function(query, lang, defaults, bufnr)
return nil
end

local parser = parsers.get_parser(bufnr, lang)
local parser = vim.treesitter.get_parser(bufnr, lang)
local root = parser:parse()[1]:root()
local start_row, _, end_row, _ = root:range()
local results = {}
for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do
for pattern, match, metadata in parsed_query:iter_matches(root, bufnr, start_row, end_row) do
local sRow, sCol, eRow, eCol
local declaration_node
local type = "nil"
local name = "nil"
locals.recurse_local_nodes(match, function(_, node, path)
recurse_captures(match, parsed_query, function(_, node, path)
local idx = string.find(path, ".", 1, true)
local op = string.sub(path, idx + 1, #path)

-- local a1, b1, c1, d1 = vim.treesitter.get_node_range(node)

type = string.sub(path, 1, idx - 1)
if name == nil then
name = defaults[type] or "empty"
Expand All @@ -78,11 +84,7 @@ PythonTreeSitterNodes.get_nodes = function(query, lang, defaults, bufnr)
name = get_node_text(node, bufnr)
elseif op == "declaration" then
declaration_node = node
sRow, sCol, eRow, eCol = node:range()
sRow = sRow + 1
eRow = eRow + 1
sCol = sCol + 1
eCol = eCol + 1
sRow, sCol, eRow, eCol = get_vim_range(node)
end
end)

Expand Down Expand Up @@ -124,12 +126,12 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos
return nil
end

local parser = parsers.get_parser(bufnr, lang)
local parser = vim.treesitter.get_parser(bufnr, lang)
local root = parser:parse()[1]:root()
local start_row, _, end_row, _ = root:range()
local results = {}
local node_type
for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do
for pattern, match, metadata in parsed_query:iter_matches(root, bufnr, start_row, end_row) do
local sRow, sCol, eRow, eCol
local declaration_node
local type_node
Expand All @@ -139,14 +141,13 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos
-- local method_receiver = ""
-- ulog(match)

locals.recurse_local_nodes(match, function(_, node, path)
recurse_captures(match, parsed_query, function(_, node, path)
-- local idx = string.find(path, ".", 1, true)
-- The query may return multiple nodes, e.g.
-- (type_declaration (type_spec name:(type_identifier)@type_decl.name type:(type_identifier)@type_decl.type))@type_decl.declaration
-- returns { { @type_decl.name, @type_decl.type, @type_decl.declaration} ... }
local idx = string.find(path, ".[^.]*$") -- find last `.`
op = string.sub(path, idx + 1, #path)
local a1, b1, c1, d1 = vim.treesitter.get_node_range(node)
local dbg_txt = get_node_text(node, bufnr) or ""
if #dbg_txt > 100 then
dbg_txt = string.sub(dbg_txt, 1, 100) .. "..."
Expand All @@ -172,8 +173,7 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos
type_node = node
elseif op == 'declaration' or op == 'clause' then
declaration_node = node
sRow, sCol, eRow, eCol =
ts_utils.get_vim_range({ vim.treesitter.get_node_range(node) }, bufnr)
sRow, sCol, eRow, eCol = get_vim_range(node)
else
-- ulog('unknown op: ' .. op)
end
Expand All @@ -194,7 +194,7 @@ PythonTreeSitterNodes.get_all_nodes = function(query, lang, defaults, bufnr, pos
end
if type_node ~= nil and ntype then
-- ulog('type_only')
sRow, sCol, eRow, eCol = ts_utils.get_vim_range({ vim.treesitter.get_node_range(type_node) }, bufnr)
sRow, sCol, eRow, eCol = get_vim_range(type_node)
table.insert(results, {
type_node = type_node,
dim = { s = { r = sRow, c = sCol }, e = { r = eRow, c = eCol } },
Expand Down Expand Up @@ -266,7 +266,7 @@ PythonTreeSitterNodes.nodes_at_cursor = function(query, default, bufnr, ntype)
end

function PythonTreeSitterNodes.inside_function()
local current_node = ts_utils.get_node_at_cursor()
local current_node = vim.treesitter.get_node()
if not current_node then
return false
end
Expand Down
55 changes: 12 additions & 43 deletions scripts/minimal_init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,51 +18,20 @@ vim.cmd("set rtp+=" .. runtime_path)
require("luasnip").setup()
require("luasnip.extras.fmt")
require("luasnip.nodes.absolute_indexer")
require("nvim-treesitter.locals")
require("nvim-treesitter").setup()
require("mini.test").setup()
require("mini.doc").setup()
require("nvim-treesitter.configs").setup({
modules = {
"highlight",
},
sync_install = false,
auto_install = true,
ignore_install = {},
ensure_installed = {},
highlight = {
enable = true,
},
})

-- Clean path for use in a prefix comparison
---@param input string
---@return string
local function clean_path(input)
local pth = vim.fn.fnamemodify(input, ":p")
if vim.fn.has("win32") == 1 then
pth = pth:gsub("/", "\\")
end
return pth
end
-- Setup nvim-treesitter
require("nvim-treesitter").install("python")

local function ts_is_installed(lang)
local matched_parsers = vim.api.nvim_get_runtime_file("parser/" .. lang .. ".so", true) or {}
local configs = require("nvim-treesitter.configs")
local install_dir = configs.get_parser_install_dir()
if not install_dir then
return false
end
install_dir = clean_path(install_dir)
for _, path in ipairs(matched_parsers) do
local abspath = clean_path(path)
if vim.startswith(abspath, install_dir) then
return true
end
end
return false
end
vim.api.nvim_create_autocmd("FileType", {
group = vim.api.nvim_create_augroup("PythonTreesitter", { clear = true }),
pattern = "python",
desc = "Enable treesitter highlighting and indentation",
callback = function(event)
local buf = event.buf

if not ts_is_installed("python") then
vim.cmd("TSInstallSync python")
end
-- Start highlighting
pcall(vim.treesitter.start, buf, "python")
end,
})
Loading