From 0f11f27c8e26263e1985e519fbd406c4454f61b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 18 Dec 2025 13:33:08 +0000 Subject: [PATCH] Separate build data structs from serializations We have three versions of the build format (v1-3), where we converted prior versions to the latest and used the latest as the actual configuration. The downside of this approach is that every time we introduce a new version, we have to move all utility functions, etc. to the latest. This change decouples the configuration that is used to generate the build files from the concrete serializations. The core data structures move to the top-level module and v1-v3.rs now only contain plain data structures to (de)serialize with serde. --- build2cmake/src/config/common.rs | 24 -- build2cmake/src/config/compat.rs | 39 +++ build2cmake/src/config/deps.rs | 89 ++++++ build2cmake/src/config/mod.rs | 284 ++++++++++++++++--- build2cmake/src/config/v1.rs | 103 ++++++- build2cmake/src/config/v2.rs | 193 +++++++++---- build2cmake/src/config/v3.rs | 473 +++++++++++++------------------ build2cmake/src/main.rs | 5 +- 8 files changed, 813 insertions(+), 397 deletions(-) delete mode 100644 build2cmake/src/config/common.rs create mode 100644 build2cmake/src/config/compat.rs create mode 100644 build2cmake/src/config/deps.rs diff --git a/build2cmake/src/config/common.rs b/build2cmake/src/config/common.rs deleted file mode 100644 index 7b0eeced..00000000 --- a/build2cmake/src/config/common.rs +++ /dev/null @@ -1,24 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -#[non_exhaustive] -#[serde(rename_all = "lowercase")] -pub enum Dependency { - #[serde(rename = "cutlass_2_10")] - Cutlass2_10, - #[serde(rename = "cutlass_3_5")] - Cutlass3_5, - #[serde(rename = "cutlass_3_6")] - Cutlass3_6, - #[serde(rename = "cutlass_3_8")] - Cutlass3_8, - #[serde(rename = "cutlass_3_9")] - Cutlass3_9, - #[serde(rename = "cutlass_4_0")] - Cutlass4_0, - #[serde(rename = "cutlass_sycl")] - CutlassSycl, - #[serde(rename = "metal-cpp")] - MetalCpp, - Torch, -} diff --git a/build2cmake/src/config/compat.rs b/build2cmake/src/config/compat.rs new file mode 100644 index 00000000..c8553a4a --- /dev/null +++ b/build2cmake/src/config/compat.rs @@ -0,0 +1,39 @@ +use eyre::Result; +use serde::Deserialize; +use serde_value::Value; + +use super::{v1, v2, v3, Build}; + +#[derive(Debug)] +pub enum BuildCompat { + V1(v1::Build), + V2(v2::Build), + V3(v3::Build), +} + +impl<'de> Deserialize<'de> for BuildCompat { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + v1::Build::deserialize(value.clone()) + .map(BuildCompat::V1) + .or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2)) + .or_else(|_| v3::Build::deserialize(value.clone()).map(BuildCompat::V3)) + .map_err(serde::de::Error::custom) + } +} + +impl TryFrom for Build { + type Error = eyre::Error; + + fn try_from(compat: BuildCompat) -> Result { + match compat { + BuildCompat::V1(v1_build) => v1_build.try_into(), + BuildCompat::V2(v2_build) => v2_build.try_into(), + BuildCompat::V3(v3_build) => Ok(v3_build.into()), + } + } +} diff --git a/build2cmake/src/config/deps.rs b/build2cmake/src/config/deps.rs new file mode 100644 index 00000000..445d8cf1 --- /dev/null +++ b/build2cmake/src/config/deps.rs @@ -0,0 +1,89 @@ +use std::{collections::HashMap, sync::LazyLock}; + +use eyre::Result; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::Backend; + +pub static PYTHON_DEPENDENCIES: LazyLock = + LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap()); + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +#[non_exhaustive] +#[serde(rename_all = "lowercase")] +pub enum Dependency { + #[serde(rename = "cutlass_2_10")] + Cutlass2_10, + #[serde(rename = "cutlass_3_5")] + Cutlass3_5, + #[serde(rename = "cutlass_3_6")] + Cutlass3_6, + #[serde(rename = "cutlass_3_8")] + Cutlass3_8, + #[serde(rename = "cutlass_3_9")] + Cutlass3_9, + #[serde(rename = "cutlass_4_0")] + Cutlass4_0, + #[serde(rename = "cutlass_sycl")] + CutlassSycl, + #[serde(rename = "metal-cpp")] + MetalCpp, + Torch, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct PythonDependencies { + general: HashMap, + backends: HashMap>, +} + +impl PythonDependencies { + pub fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> { + match self.general.get(dependency) { + None => Err(DependencyError::GeneralDependency { + dependency: dependency.to_string(), + }), + Some(dep) => Ok(&dep.python), + } + } + + pub fn get_backend_dependency( + &self, + backend: Backend, + dependency: &str, + ) -> Result<&[String], DependencyError> { + let backend_deps = match self.backends.get(&backend) { + None => { + return Err(DependencyError::Backend { + backend: backend.to_string(), + }) + } + Some(backend_deps) => backend_deps, + }; + match backend_deps.get(dependency) { + None => Err(DependencyError::Dependency { + backend: backend.to_string(), + dependency: dependency.to_string(), + }), + Some(dep) => Ok(&dep.python), + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +struct PythonDependency { + nix: Vec, + python: Vec, +} + +#[derive(Debug, Error)] +pub enum DependencyError { + #[error("No dependencies are defined for backend: {backend:?}")] + Backend { backend: String }, + #[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")] + Dependency { backend: String, dependency: String }, + #[error("Unknown dependency: `{dependency:?}`")] + GeneralDependency { dependency: String }, +} diff --git a/build2cmake/src/config/mod.rs b/build2cmake/src/config/mod.rs index f9dc234d..816b4fea 100644 --- a/build2cmake/src/config/mod.rs +++ b/build2cmake/src/config/mod.rs @@ -1,50 +1,270 @@ +use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr}; + use eyre::Result; -use serde::Deserialize; -use serde_value::Value; +use serde::{Deserialize, Serialize}; -pub mod v1; +mod deps; +pub use deps::Dependency; -mod common; +mod compat; +pub use compat::BuildCompat; +mod v1; mod v2; +pub(crate) mod v3; + +use itertools::Itertools; -mod v3; -pub use common::Dependency; -pub use v3::{Backend, Build, General, Kernel, Torch}; +use crate::version::Version; -#[derive(Debug)] -pub enum BuildCompat { - V1(v1::Build), - V2(v2::Build), - V3(Build), +pub struct Build { + pub general: General, + pub kernels: HashMap, + pub torch: Option, } -impl<'de> Deserialize<'de> for BuildCompat { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +impl Build { + pub fn is_noarch(&self) -> bool { + self.kernels.is_empty() + } - v1::Build::deserialize(value.clone()) - .map(BuildCompat::V1) - .or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2)) - .or_else(|_| Build::deserialize(value.clone()).map(BuildCompat::V3)) - .map_err(serde::de::Error::custom) + pub fn supports_backend(&self, backend: &Backend) -> bool { + self.general.backends.contains(backend) } } -impl TryFrom for Build { - type Error = eyre::Error; +pub struct General { + pub name: String, + pub backends: Vec, + pub hub: Option, + pub python_depends: Option>, + + pub cuda: Option, + pub xpu: Option, +} + +impl General { + /// Name of the kernel as a Python extension. + pub fn python_name(&self) -> String { + self.name.replace("-", "_") + } + + pub fn python_depends(&self) -> Box> + '_> { + let general_python_deps = match self.python_depends.as_ref() { + Some(deps) => deps, + None => { + return Box::new(std::iter::empty()); + } + }; + + Box::new(general_python_deps.iter().flat_map(move |dep| { + match deps::PYTHON_DEPENDENCIES.get_dependency(dep) { + Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), + Err(e) => vec![Err(e.into())], + } + })) + } + + pub fn backend_python_depends( + &self, + backend: Backend, + ) -> Box> + '_> { + let backend_python_deps = match backend { + Backend::Cuda => self + .cuda + .as_ref() + .and_then(|cuda| cuda.python_depends.as_ref()), + Backend::Xpu => self + .xpu + .as_ref() + .and_then(|xpu| xpu.python_depends.as_ref()), + _ => None, + }; + + let backend_python_deps = match backend_python_deps { + Some(deps) => deps, + None => { + return Box::new(std::iter::empty()); + } + }; - fn try_from(compat: BuildCompat) -> Result { - match compat { - BuildCompat::V1(v1_build) => { - let v2_build: v2::Build = v1_build.try_into()?; - v2_build.try_into() + Box::new(backend_python_deps.iter().flat_map(move |dep| { + match deps::PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) { + Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), + Err(e) => vec![Err(e.into())], } - BuildCompat::V2(v2_build) => v2_build.try_into(), - BuildCompat::V3(v3_build) => Ok(v3_build), + })) + } +} + +pub struct CudaGeneral { + pub minver: Option, + pub maxver: Option, + pub python_depends: Option>, +} + +pub struct XpuGeneral { + pub python_depends: Option>, +} + +pub struct Hub { + pub repo_id: Option, + pub branch: Option, +} + +pub struct Torch { + pub include: Option>, + pub minver: Option, + pub maxver: Option, + pub pyext: Option>, + pub src: Vec, +} + +impl Torch { + pub fn data_globs(&self) -> Option> { + match self.pyext.as_ref() { + Some(exts) => { + let globs = exts + .iter() + .filter(|&ext| ext != "py" && ext != "pyi") + .map(|ext| format!("\"**/*.{ext}\"")) + .collect_vec(); + if globs.is_empty() { + None + } else { + Some(globs) + } + } + + None => None, + } + } +} + +pub enum Kernel { + Cpu { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + Cuda { + cuda_capabilities: Option>, + cuda_flags: Option>, + cuda_minver: Option, + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + Metal { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + Rocm { + cxx_flags: Option>, + depends: Vec, + rocm_archs: Option>, + hip_flags: Option>, + include: Option>, + src: Vec, + }, + Xpu { + cxx_flags: Option>, + depends: Vec, + sycl_flags: Option>, + include: Option>, + src: Vec, + }, +} + +impl Kernel { + pub fn cxx_flags(&self) -> Option<&[String]> { + match self { + Kernel::Cpu { cxx_flags, .. } + | Kernel::Cuda { cxx_flags, .. } + | Kernel::Metal { cxx_flags, .. } + | Kernel::Rocm { cxx_flags, .. } + | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), + } + } + + pub fn include(&self) -> Option<&[String]> { + match self { + Kernel::Cpu { include, .. } + | Kernel::Cuda { include, .. } + | Kernel::Metal { include, .. } + | Kernel::Rocm { include, .. } + | Kernel::Xpu { include, .. } => include.as_deref(), + } + } + + pub fn backend(&self) -> Backend { + match self { + Kernel::Cpu { .. } => Backend::Cpu, + Kernel::Cuda { .. } => Backend::Cuda, + Kernel::Metal { .. } => Backend::Metal, + Kernel::Rocm { .. } => Backend::Rocm, + Kernel::Xpu { .. } => Backend::Xpu, + } + } + + pub fn depends(&self) -> &[Dependency] { + match self { + Kernel::Cpu { depends, .. } + | Kernel::Cuda { depends, .. } + | Kernel::Metal { depends, .. } + | Kernel::Rocm { depends, .. } + | Kernel::Xpu { depends, .. } => depends, + } + } + + pub fn src(&self) -> &[String] { + match self { + Kernel::Cpu { src, .. } + | Kernel::Cuda { src, .. } + | Kernel::Metal { src, .. } + | Kernel::Rocm { src, .. } + | Kernel::Xpu { src, .. } => src, + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum Backend { + Cpu, + Cuda, + Metal, + Rocm, + Xpu, +} + +impl Display for Backend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Backend::Cpu => write!(f, "cpu"), + Backend::Cuda => write!(f, "cuda"), + Backend::Metal => write!(f, "metal"), + Backend::Rocm => write!(f, "rocm"), + Backend::Xpu => write!(f, "xpu"), + } + } +} + +impl FromStr for Backend { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "cpu" => Ok(Backend::Cpu), + "cuda" => Ok(Backend::Cuda), + "metal" => Ok(Backend::Metal), + "rocm" => Ok(Backend::Rocm), + "xpu" => Ok(Backend::Xpu), + _ => Err(format!("Unknown backend: {s}")), } } } diff --git a/build2cmake/src/config/v1.rs b/build2cmake/src/config/v1.rs index 5582d7ff..608f35d5 100644 --- a/build2cmake/src/config/v1.rs +++ b/build2cmake/src/config/v1.rs @@ -1,8 +1,13 @@ -use std::{collections::HashMap, fmt::Display, path::PathBuf}; +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, + path::PathBuf, +}; +use eyre::{bail, Result}; use serde::Deserialize; -use super::common::Dependency; +use super::{Backend, Dependency}; #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] @@ -63,3 +68,97 @@ impl Display for Language { } } } + +impl TryFrom for super::Build { + type Error = eyre::Error; + + fn try_from(build: Build) -> Result { + let universal = build + .torch + .as_ref() + .map(|torch| torch.universal) + .unwrap_or(false); + + let kernels = convert_kernels(build.kernels)?; + + let backends = if universal { + vec![ + Backend::Cpu, + Backend::Cuda, + Backend::Metal, + Backend::Rocm, + Backend::Xpu, + ] + } else { + let backend_set: BTreeSet = + kernels.values().map(|kernel| kernel.backend()).collect(); + backend_set.into_iter().collect() + }; + + Ok(Self { + general: super::General { + name: build.general.name, + backends, + hub: None, + python_depends: None, + cuda: None, + xpu: None, + }, + torch: build.torch.map(Into::into), + kernels, + }) + } +} + +fn convert_kernels(v1_kernels: HashMap) -> Result> { + let mut kernels = HashMap::new(); + + for (name, kernel) in v1_kernels { + if kernel.language == Language::CudaHipify { + // We need to add an affix to avoid conflict with the CUDA kernel. + let rocm_name = format!("{name}_rocm"); + if kernels.contains_key(&rocm_name) { + bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`") + } + + kernels.insert( + format!("{name}_rocm"), + super::Kernel::Rocm { + cxx_flags: None, + rocm_archs: kernel.rocm_archs, + hip_flags: None, + depends: kernel.depends.clone(), + include: kernel.include.clone(), + src: kernel.src.clone(), + }, + ); + } + + kernels.insert( + name, + super::Kernel::Cuda { + cuda_capabilities: kernel.cuda_capabilities, + cuda_flags: None, + cuda_minver: None, + cxx_flags: None, + depends: kernel.depends, + include: kernel.include, + src: kernel.src, + }, + ); + } + + Ok(kernels) +} + +impl From for super::Torch { + fn from(torch: Torch) -> Self { + Self { + include: torch.include, + minver: None, + maxver: None, + pyext: torch.pyext, + src: torch.src, + } + } +} diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 20b906bc..43e22117 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -1,15 +1,15 @@ -use std::{collections::HashMap, fmt::Display, path::PathBuf}; +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, + path::PathBuf, +}; -use eyre::{bail, Result}; +use eyre::Result; use serde::{Deserialize, Serialize}; +use super::{Backend, Dependency}; use crate::version::Version; -use super::{ - common::Dependency, - v1::{self, Language}, -}; - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Build { @@ -117,85 +117,152 @@ pub enum Kernel { }, } -impl TryFrom for Build { +impl TryFrom for super::Build { type Error = eyre::Error; - fn try_from(build: v1::Build) -> Result { - let universal = build - .torch - .as_ref() - .map(|torch| torch.universal) - .unwrap_or(false); + fn try_from(build: Build) -> Result { + let kernels: HashMap = build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + let backends = if build.general.universal { + vec![ + Backend::Cpu, + Backend::Cuda, + Backend::Metal, + Backend::Rocm, + Backend::Xpu, + ] + } else { + let backend_set: BTreeSet = + kernels.values().map(|kernel| kernel.backend()).collect(); + backend_set.into_iter().collect() + }; + Ok(Self { - general: General::from(build.general, universal), + general: General::from_v2(build.general, backends), torch: build.torch.map(Into::into), - kernels: convert_kernels(build.kernels)?, + kernels, }) } } impl General { - fn from(general: v1::General, universal: bool) -> Self { - Self { + fn from_v2(general: General, backends: Vec) -> super::General { + let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { + Some(super::CudaGeneral { + minver: general.cuda_minver, + maxver: general.cuda_maxver, + python_depends: None, + }) + } else { + None + }; + + super::General { name: general.name, - universal, - cuda_maxver: None, - cuda_minver: None, - hub: None, + backends, + cuda, + hub: general.hub.map(Into::into), python_depends: None, + xpu: None, } } } -fn convert_kernels(v1_kernels: HashMap) -> Result> { - let mut kernels = HashMap::new(); - - for (name, kernel) in v1_kernels { - if kernel.language == Language::CudaHipify { - // We need to add an affix to avoid conflict with the CUDA kernel. - let rocm_name = format!("{name}_rocm"); - if kernels.contains_key(&rocm_name) { - bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`") - } - - kernels.insert( - format!("{name}_rocm"), - Kernel::Rocm { - cxx_flags: None, - rocm_archs: kernel.rocm_archs, - hip_flags: None, - depends: kernel.depends.clone(), - include: kernel.include.clone(), - src: kernel.src.clone(), - }, - ); +impl From for super::Hub { + fn from(hub: Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, } - - kernels.insert( - name, - Kernel::Cuda { - cuda_capabilities: kernel.cuda_capabilities, - cuda_flags: None, - cuda_minver: None, - cxx_flags: None, - depends: kernel.depends, - include: kernel.include, - src: kernel.src, - }, - ); } - - Ok(kernels) } -impl From for Torch { - fn from(torch: v1::Torch) -> Self { +impl From for super::Torch { + fn from(torch: Torch) -> Self { Self { include: torch.include, - minver: None, - maxver: None, + minver: torch.minver, + maxver: torch.maxver, pyext: torch.pyext, src: torch.src, } } } + +impl From for super::Kernel { + fn from(kernel: Kernel) -> Self { + match kernel { + Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => super::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + } => super::Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, + } + } +} diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index 14c25418..770ec05a 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -1,78 +1,11 @@ -use std::{ - collections::{BTreeSet, HashMap}, - fmt::Display, - path::PathBuf, - str::FromStr, - sync::LazyLock, -}; - -use eyre::Result; -use itertools::Itertools; +use std::collections::HashMap; +use std::path::PathBuf; + use serde::{Deserialize, Serialize}; -use thiserror::Error; -use super::{common::Dependency, v2}; +use super::Dependency; use crate::version::Version; -#[derive(Debug, Error)] -enum DependencyError { - #[error("No dependencies are defined for backend: {backend:?}")] - Backend { backend: String }, - #[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")] - Dependency { backend: String, dependency: String }, - #[error("Unknown dependency: `{dependency:?}`")] - GeneralDependency { dependency: String }, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -struct PythonDependencies { - general: HashMap, - backends: HashMap>, -} - -impl PythonDependencies { - fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> { - match self.general.get(dependency) { - None => Err(DependencyError::GeneralDependency { - dependency: dependency.to_string(), - }), - Some(dep) => Ok(&dep.python), - } - } - - fn get_backend_dependency( - &self, - backend: Backend, - dependency: &str, - ) -> Result<&[String], DependencyError> { - let backend_deps = match self.backends.get(&backend) { - None => { - return Err(DependencyError::Backend { - backend: backend.to_string(), - }) - } - Some(backend_deps) => backend_deps, - }; - match backend_deps.get(dependency) { - None => Err(DependencyError::Dependency { - backend: backend.to_string(), - dependency: dependency.to_string(), - }), - Some(dep) => Ok(&dep.python), - } - } -} - -#[derive(Debug, Deserialize, Serialize)] -struct PythonDependency { - nix: Vec, - python: Vec, -} - -static PYTHON_DEPENDENCIES: LazyLock = - LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap()); - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Build { @@ -83,16 +16,6 @@ pub struct Build { pub kernels: HashMap, } -impl Build { - pub fn is_noarch(&self) -> bool { - self.kernels.is_empty() - } - - pub fn supports_backend(&self, backend: &Backend) -> bool { - self.general.backends.contains(backend) - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct General { @@ -109,60 +32,6 @@ pub struct General { pub xpu: Option, } -impl General { - /// Name of the kernel as a Python extension. - pub fn python_name(&self) -> String { - self.name.replace("-", "_") - } - - pub fn python_depends(&self) -> Box> + '_> { - let general_python_deps = match self.python_depends.as_ref() { - Some(deps) => deps, - None => { - return Box::new(std::iter::empty()); - } - }; - - Box::new(general_python_deps.iter().flat_map(move |dep| { - match PYTHON_DEPENDENCIES.get_dependency(dep) { - Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), - Err(e) => vec![Err(e.into())], - } - })) - } - - pub fn backend_python_depends( - &self, - backend: Backend, - ) -> Box> + '_> { - let backend_python_deps = match backend { - Backend::Cuda => self - .cuda - .as_ref() - .and_then(|cuda| cuda.python_depends.as_ref()), - Backend::Xpu => self - .xpu - .as_ref() - .and_then(|xpu| xpu.python_depends.as_ref()), - _ => None, - }; - - let backend_python_deps = match backend_python_deps { - Some(deps) => deps, - None => { - return Box::new(std::iter::empty()); - } - }; - - Box::new(backend_python_deps.iter().flat_map(move |dep| { - match PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) { - Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), - Err(e) => vec![Err(e.into())], - } - })) - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct CudaGeneral { @@ -196,27 +65,6 @@ pub struct Torch { pub src: Vec, } -impl Torch { - pub fn data_globs(&self) -> Option> { - match self.pyext.as_ref() { - Some(exts) => { - let globs = exts - .iter() - .filter(|&ext| ext != "py" && ext != "pyi") - .map(|ext| format!("\"**/*.{ext}\"")) - .collect_vec(); - if globs.is_empty() { - None - } else { - Some(globs) - } - } - - None => None, - } - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")] pub enum Kernel { @@ -263,152 +111,217 @@ pub enum Kernel { }, } -impl Kernel { - pub fn cxx_flags(&self) -> Option<&[String]> { - match self { - Kernel::Cpu { cxx_flags, .. } - | Kernel::Cuda { cxx_flags, .. } - | Kernel::Metal { cxx_flags, .. } - | Kernel::Rocm { cxx_flags, .. } - | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum Backend { + Cpu, + Cuda, + Metal, + Rocm, + Xpu, +} + +impl From for super::Build { + fn from(build: Build) -> Self { + let kernels: HashMap = build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + Self { + general: build.general.into(), + torch: build.torch.map(Into::into), + kernels, } } +} - pub fn include(&self) -> Option<&[String]> { - match self { - Kernel::Cpu { include, .. } - | Kernel::Cuda { include, .. } - | Kernel::Metal { include, .. } - | Kernel::Rocm { include, .. } - | Kernel::Xpu { include, .. } => include.as_deref(), +impl From for super::General { + fn from(general: General) -> Self { + Self { + name: general.name, + backends: general.backends.into_iter().map(Into::into).collect(), + cuda: general.cuda.map(Into::into), + hub: general.hub.map(Into::into), + python_depends: general.python_depends, + xpu: general.xpu.map(Into::into), } } +} - pub fn backend(&self) -> Backend { - match self { - Kernel::Cpu { .. } => Backend::Cpu, - Kernel::Cuda { .. } => Backend::Cuda, - Kernel::Metal { .. } => Backend::Metal, - Kernel::Rocm { .. } => Backend::Rocm, - Kernel::Xpu { .. } => Backend::Xpu, +impl From for super::CudaGeneral { + fn from(cuda: CudaGeneral) -> Self { + Self { + minver: cuda.minver, + maxver: cuda.maxver, + python_depends: cuda.python_depends, } } +} - pub fn depends(&self) -> &[Dependency] { - match self { - Kernel::Cpu { depends, .. } - | Kernel::Cuda { depends, .. } - | Kernel::Metal { depends, .. } - | Kernel::Rocm { depends, .. } - | Kernel::Xpu { depends, .. } => depends, +impl From for super::XpuGeneral { + fn from(xpu: XpuGeneral) -> Self { + Self { + python_depends: xpu.python_depends, } } +} - pub fn src(&self) -> &[String] { - match self { - Kernel::Cpu { src, .. } - | Kernel::Cuda { src, .. } - | Kernel::Metal { src, .. } - | Kernel::Rocm { src, .. } - | Kernel::Xpu { src, .. } => src, +impl From for super::Hub { + fn from(hub: Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, } } } -#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] -#[serde(deny_unknown_fields, rename_all = "kebab-case")] -pub enum Backend { - Cpu, - Cuda, - Metal, - Rocm, - Xpu, +impl From for super::Torch { + fn from(torch: Torch) -> Self { + Self { + include: torch.include, + minver: torch.minver, + maxver: torch.maxver, + pyext: torch.pyext, + src: torch.src, + } + } } -impl Display for Backend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Backend::Cpu => write!(f, "cpu"), - Backend::Cuda => write!(f, "cuda"), - Backend::Metal => write!(f, "metal"), - Backend::Rocm => write!(f, "rocm"), - Backend::Xpu => write!(f, "xpu"), +impl From for super::Backend { + fn from(backend: Backend) -> Self { + match backend { + Backend::Cpu => super::Backend::Cpu, + Backend::Cuda => super::Backend::Cuda, + Backend::Metal => super::Backend::Metal, + Backend::Rocm => super::Backend::Rocm, + Backend::Xpu => super::Backend::Xpu, } } } -impl FromStr for Backend { - type Err = String; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "cpu" => Ok(Backend::Cpu), - "cuda" => Ok(Backend::Cuda), - "metal" => Ok(Backend::Metal), - "rocm" => Ok(Backend::Rocm), - "xpu" => Ok(Backend::Xpu), - _ => Err(format!("Unknown backend: {s}")), +impl From for super::Kernel { + fn from(kernel: Kernel) -> Self { + match kernel { + Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => super::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + } => super::Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, } } } -impl TryFrom for Build { - type Error = eyre::Error; - - fn try_from(build: v2::Build) -> Result { - let kernels: HashMap = build - .kernels - .into_iter() - .map(|(k, v)| (k, v.into())) - .collect(); - - let backends = if build.general.universal { - vec![ - Backend::Cpu, - Backend::Cuda, - Backend::Metal, - Backend::Rocm, - Backend::Xpu, - ] - } else { - let backend_set: BTreeSet = - kernels.values().map(|kernel| kernel.backend()).collect(); - backend_set.into_iter().collect() - }; - - Ok(Self { - general: General::from_v2(build.general, backends), +impl From for Build { + fn from(build: super::Build) -> Self { + Self { + general: build.general.into(), torch: build.torch.map(Into::into), - kernels, - }) + kernels: build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), + } } } -impl General { - fn from_v2(general: v2::General, backends: Vec) -> Self { - let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { - Some(CudaGeneral { - minver: general.cuda_minver, - maxver: general.cuda_maxver, - python_depends: None, - }) - } else { - None - }; - +impl From for General { + fn from(general: super::General) -> Self { Self { name: general.name, - backends, - cuda, + backends: general.backends.into_iter().map(Into::into).collect(), + cuda: general.cuda.map(Into::into), hub: general.hub.map(Into::into), - python_depends: None, - xpu: None, + python_depends: general.python_depends, + xpu: general.xpu.map(Into::into), + } + } +} + +impl From for CudaGeneral { + fn from(cuda: super::CudaGeneral) -> Self { + Self { + minver: cuda.minver, + maxver: cuda.maxver, + python_depends: cuda.python_depends, + } + } +} + +impl From for XpuGeneral { + fn from(xpu: super::XpuGeneral) -> Self { + Self { + python_depends: xpu.python_depends, } } } -impl From for Hub { - fn from(hub: v2::Hub) -> Self { +impl From for Hub { + fn from(hub: super::Hub) -> Self { Self { repo_id: hub.repo_id, branch: hub.branch, @@ -416,8 +329,8 @@ impl From for Hub { } } -impl From for Torch { - fn from(torch: v2::Torch) -> Self { +impl From for Torch { + fn from(torch: super::Torch) -> Self { Self { include: torch.include, minver: torch.minver, @@ -428,10 +341,22 @@ impl From for Torch { } } -impl From for Kernel { - fn from(kernel: v2::Kernel) -> Self { +impl From for Backend { + fn from(backend: super::Backend) -> Self { + match backend { + super::Backend::Cpu => Backend::Cpu, + super::Backend::Cuda => Backend::Cuda, + super::Backend::Metal => Backend::Metal, + super::Backend::Rocm => Backend::Rocm, + super::Backend::Xpu => Backend::Xpu, + } + } +} + +impl From for Kernel { + fn from(kernel: super::Kernel) -> Self { match kernel { - v2::Kernel::Cpu { + super::Kernel::Cpu { cxx_flags, depends, include, @@ -442,7 +367,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Cuda { + super::Kernel::Cuda { cuda_capabilities, cuda_flags, cuda_minver, @@ -459,7 +384,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Metal { + super::Kernel::Metal { cxx_flags, depends, include, @@ -470,7 +395,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Rocm { + super::Kernel::Rocm { cxx_flags, depends, rocm_archs, @@ -485,7 +410,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Xpu { + super::Kernel::Xpu { cxx_flags, depends, sycl_flags, diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index f8e0d0b5..38d2f0d3 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -15,7 +15,7 @@ use torch::{ }; mod config; -use config::{Backend, Build, BuildCompat}; +use config::{v3, Backend, Build, BuildCompat}; mod fileset; use fileset::FileSet; @@ -200,7 +200,8 @@ fn update_build(build_toml: PathBuf) -> Result<()> { let build: Build = build_compat .try_into() .context("Cannot update build configuration")?; - let pretty_toml = toml::to_string_pretty(&build)?; + let v3_build: v3::Build = build.into(); + let pretty_toml = toml::to_string_pretty(&v3_build)?; let mut writer = BufWriter::new(File::create(&build_toml).wrap_err_with(|| {