diff --git a/build2cmake/src/config/common.rs b/build2cmake/src/config/common.rs deleted file mode 100644 index 7b0eeced..00000000 --- a/build2cmake/src/config/common.rs +++ /dev/null @@ -1,24 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -#[non_exhaustive] -#[serde(rename_all = "lowercase")] -pub enum Dependency { - #[serde(rename = "cutlass_2_10")] - Cutlass2_10, - #[serde(rename = "cutlass_3_5")] - Cutlass3_5, - #[serde(rename = "cutlass_3_6")] - Cutlass3_6, - #[serde(rename = "cutlass_3_8")] - Cutlass3_8, - #[serde(rename = "cutlass_3_9")] - Cutlass3_9, - #[serde(rename = "cutlass_4_0")] - Cutlass4_0, - #[serde(rename = "cutlass_sycl")] - CutlassSycl, - #[serde(rename = "metal-cpp")] - MetalCpp, - Torch, -} diff --git a/build2cmake/src/config/compat.rs b/build2cmake/src/config/compat.rs new file mode 100644 index 00000000..c8553a4a --- /dev/null +++ b/build2cmake/src/config/compat.rs @@ -0,0 +1,39 @@ +use eyre::Result; +use serde::Deserialize; +use serde_value::Value; + +use super::{v1, v2, v3, Build}; + +#[derive(Debug)] +pub enum BuildCompat { + V1(v1::Build), + V2(v2::Build), + V3(v3::Build), +} + +impl<'de> Deserialize<'de> for BuildCompat { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + + v1::Build::deserialize(value.clone()) + .map(BuildCompat::V1) + .or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2)) + .or_else(|_| v3::Build::deserialize(value.clone()).map(BuildCompat::V3)) + .map_err(serde::de::Error::custom) + } +} + +impl TryFrom for Build { + type Error = eyre::Error; + + fn try_from(compat: BuildCompat) -> Result { + match compat { + BuildCompat::V1(v1_build) => v1_build.try_into(), + BuildCompat::V2(v2_build) => v2_build.try_into(), + BuildCompat::V3(v3_build) => Ok(v3_build.into()), + } + } +} diff --git a/build2cmake/src/config/deps.rs b/build2cmake/src/config/deps.rs new file mode 100644 index 00000000..445d8cf1 --- /dev/null +++ b/build2cmake/src/config/deps.rs @@ -0,0 +1,89 @@ +use std::{collections::HashMap, sync::LazyLock}; + +use eyre::Result; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use super::Backend; + +pub static PYTHON_DEPENDENCIES: LazyLock = + LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap()); + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +#[non_exhaustive] +#[serde(rename_all = "lowercase")] +pub enum Dependency { + #[serde(rename = "cutlass_2_10")] + Cutlass2_10, + #[serde(rename = "cutlass_3_5")] + Cutlass3_5, + #[serde(rename = "cutlass_3_6")] + Cutlass3_6, + #[serde(rename = "cutlass_3_8")] + Cutlass3_8, + #[serde(rename = "cutlass_3_9")] + Cutlass3_9, + #[serde(rename = "cutlass_4_0")] + Cutlass4_0, + #[serde(rename = "cutlass_sycl")] + CutlassSycl, + #[serde(rename = "metal-cpp")] + MetalCpp, + Torch, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct PythonDependencies { + general: HashMap, + backends: HashMap>, +} + +impl PythonDependencies { + pub fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> { + match self.general.get(dependency) { + None => Err(DependencyError::GeneralDependency { + dependency: dependency.to_string(), + }), + Some(dep) => Ok(&dep.python), + } + } + + pub fn get_backend_dependency( + &self, + backend: Backend, + dependency: &str, + ) -> Result<&[String], DependencyError> { + let backend_deps = match self.backends.get(&backend) { + None => { + return Err(DependencyError::Backend { + backend: backend.to_string(), + }) + } + Some(backend_deps) => backend_deps, + }; + match backend_deps.get(dependency) { + None => Err(DependencyError::Dependency { + backend: backend.to_string(), + dependency: dependency.to_string(), + }), + Some(dep) => Ok(&dep.python), + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +struct PythonDependency { + nix: Vec, + python: Vec, +} + +#[derive(Debug, Error)] +pub enum DependencyError { + #[error("No dependencies are defined for backend: {backend:?}")] + Backend { backend: String }, + #[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")] + Dependency { backend: String, dependency: String }, + #[error("Unknown dependency: `{dependency:?}`")] + GeneralDependency { dependency: String }, +} diff --git a/build2cmake/src/config/mod.rs b/build2cmake/src/config/mod.rs index f9dc234d..816b4fea 100644 --- a/build2cmake/src/config/mod.rs +++ b/build2cmake/src/config/mod.rs @@ -1,50 +1,270 @@ +use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr}; + use eyre::Result; -use serde::Deserialize; -use serde_value::Value; +use serde::{Deserialize, Serialize}; -pub mod v1; +mod deps; +pub use deps::Dependency; -mod common; +mod compat; +pub use compat::BuildCompat; +mod v1; mod v2; +pub(crate) mod v3; + +use itertools::Itertools; -mod v3; -pub use common::Dependency; -pub use v3::{Backend, Build, General, Kernel, Torch}; +use crate::version::Version; -#[derive(Debug)] -pub enum BuildCompat { - V1(v1::Build), - V2(v2::Build), - V3(Build), +pub struct Build { + pub general: General, + pub kernels: HashMap, + pub torch: Option, } -impl<'de> Deserialize<'de> for BuildCompat { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - let value = Value::deserialize(deserializer)?; +impl Build { + pub fn is_noarch(&self) -> bool { + self.kernels.is_empty() + } - v1::Build::deserialize(value.clone()) - .map(BuildCompat::V1) - .or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2)) - .or_else(|_| Build::deserialize(value.clone()).map(BuildCompat::V3)) - .map_err(serde::de::Error::custom) + pub fn supports_backend(&self, backend: &Backend) -> bool { + self.general.backends.contains(backend) } } -impl TryFrom for Build { - type Error = eyre::Error; +pub struct General { + pub name: String, + pub backends: Vec, + pub hub: Option, + pub python_depends: Option>, + + pub cuda: Option, + pub xpu: Option, +} + +impl General { + /// Name of the kernel as a Python extension. + pub fn python_name(&self) -> String { + self.name.replace("-", "_") + } + + pub fn python_depends(&self) -> Box> + '_> { + let general_python_deps = match self.python_depends.as_ref() { + Some(deps) => deps, + None => { + return Box::new(std::iter::empty()); + } + }; + + Box::new(general_python_deps.iter().flat_map(move |dep| { + match deps::PYTHON_DEPENDENCIES.get_dependency(dep) { + Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), + Err(e) => vec![Err(e.into())], + } + })) + } + + pub fn backend_python_depends( + &self, + backend: Backend, + ) -> Box> + '_> { + let backend_python_deps = match backend { + Backend::Cuda => self + .cuda + .as_ref() + .and_then(|cuda| cuda.python_depends.as_ref()), + Backend::Xpu => self + .xpu + .as_ref() + .and_then(|xpu| xpu.python_depends.as_ref()), + _ => None, + }; + + let backend_python_deps = match backend_python_deps { + Some(deps) => deps, + None => { + return Box::new(std::iter::empty()); + } + }; - fn try_from(compat: BuildCompat) -> Result { - match compat { - BuildCompat::V1(v1_build) => { - let v2_build: v2::Build = v1_build.try_into()?; - v2_build.try_into() + Box::new(backend_python_deps.iter().flat_map(move |dep| { + match deps::PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) { + Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), + Err(e) => vec![Err(e.into())], } - BuildCompat::V2(v2_build) => v2_build.try_into(), - BuildCompat::V3(v3_build) => Ok(v3_build), + })) + } +} + +pub struct CudaGeneral { + pub minver: Option, + pub maxver: Option, + pub python_depends: Option>, +} + +pub struct XpuGeneral { + pub python_depends: Option>, +} + +pub struct Hub { + pub repo_id: Option, + pub branch: Option, +} + +pub struct Torch { + pub include: Option>, + pub minver: Option, + pub maxver: Option, + pub pyext: Option>, + pub src: Vec, +} + +impl Torch { + pub fn data_globs(&self) -> Option> { + match self.pyext.as_ref() { + Some(exts) => { + let globs = exts + .iter() + .filter(|&ext| ext != "py" && ext != "pyi") + .map(|ext| format!("\"**/*.{ext}\"")) + .collect_vec(); + if globs.is_empty() { + None + } else { + Some(globs) + } + } + + None => None, + } + } +} + +pub enum Kernel { + Cpu { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + Cuda { + cuda_capabilities: Option>, + cuda_flags: Option>, + cuda_minver: Option, + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + Metal { + cxx_flags: Option>, + depends: Vec, + include: Option>, + src: Vec, + }, + Rocm { + cxx_flags: Option>, + depends: Vec, + rocm_archs: Option>, + hip_flags: Option>, + include: Option>, + src: Vec, + }, + Xpu { + cxx_flags: Option>, + depends: Vec, + sycl_flags: Option>, + include: Option>, + src: Vec, + }, +} + +impl Kernel { + pub fn cxx_flags(&self) -> Option<&[String]> { + match self { + Kernel::Cpu { cxx_flags, .. } + | Kernel::Cuda { cxx_flags, .. } + | Kernel::Metal { cxx_flags, .. } + | Kernel::Rocm { cxx_flags, .. } + | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), + } + } + + pub fn include(&self) -> Option<&[String]> { + match self { + Kernel::Cpu { include, .. } + | Kernel::Cuda { include, .. } + | Kernel::Metal { include, .. } + | Kernel::Rocm { include, .. } + | Kernel::Xpu { include, .. } => include.as_deref(), + } + } + + pub fn backend(&self) -> Backend { + match self { + Kernel::Cpu { .. } => Backend::Cpu, + Kernel::Cuda { .. } => Backend::Cuda, + Kernel::Metal { .. } => Backend::Metal, + Kernel::Rocm { .. } => Backend::Rocm, + Kernel::Xpu { .. } => Backend::Xpu, + } + } + + pub fn depends(&self) -> &[Dependency] { + match self { + Kernel::Cpu { depends, .. } + | Kernel::Cuda { depends, .. } + | Kernel::Metal { depends, .. } + | Kernel::Rocm { depends, .. } + | Kernel::Xpu { depends, .. } => depends, + } + } + + pub fn src(&self) -> &[String] { + match self { + Kernel::Cpu { src, .. } + | Kernel::Cuda { src, .. } + | Kernel::Metal { src, .. } + | Kernel::Rocm { src, .. } + | Kernel::Xpu { src, .. } => src, + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum Backend { + Cpu, + Cuda, + Metal, + Rocm, + Xpu, +} + +impl Display for Backend { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Backend::Cpu => write!(f, "cpu"), + Backend::Cuda => write!(f, "cuda"), + Backend::Metal => write!(f, "metal"), + Backend::Rocm => write!(f, "rocm"), + Backend::Xpu => write!(f, "xpu"), + } + } +} + +impl FromStr for Backend { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "cpu" => Ok(Backend::Cpu), + "cuda" => Ok(Backend::Cuda), + "metal" => Ok(Backend::Metal), + "rocm" => Ok(Backend::Rocm), + "xpu" => Ok(Backend::Xpu), + _ => Err(format!("Unknown backend: {s}")), } } } diff --git a/build2cmake/src/config/v1.rs b/build2cmake/src/config/v1.rs index 5582d7ff..608f35d5 100644 --- a/build2cmake/src/config/v1.rs +++ b/build2cmake/src/config/v1.rs @@ -1,8 +1,13 @@ -use std::{collections::HashMap, fmt::Display, path::PathBuf}; +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, + path::PathBuf, +}; +use eyre::{bail, Result}; use serde::Deserialize; -use super::common::Dependency; +use super::{Backend, Dependency}; #[derive(Debug, Deserialize)] #[serde(deny_unknown_fields)] @@ -63,3 +68,97 @@ impl Display for Language { } } } + +impl TryFrom for super::Build { + type Error = eyre::Error; + + fn try_from(build: Build) -> Result { + let universal = build + .torch + .as_ref() + .map(|torch| torch.universal) + .unwrap_or(false); + + let kernels = convert_kernels(build.kernels)?; + + let backends = if universal { + vec![ + Backend::Cpu, + Backend::Cuda, + Backend::Metal, + Backend::Rocm, + Backend::Xpu, + ] + } else { + let backend_set: BTreeSet = + kernels.values().map(|kernel| kernel.backend()).collect(); + backend_set.into_iter().collect() + }; + + Ok(Self { + general: super::General { + name: build.general.name, + backends, + hub: None, + python_depends: None, + cuda: None, + xpu: None, + }, + torch: build.torch.map(Into::into), + kernels, + }) + } +} + +fn convert_kernels(v1_kernels: HashMap) -> Result> { + let mut kernels = HashMap::new(); + + for (name, kernel) in v1_kernels { + if kernel.language == Language::CudaHipify { + // We need to add an affix to avoid conflict with the CUDA kernel. + let rocm_name = format!("{name}_rocm"); + if kernels.contains_key(&rocm_name) { + bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`") + } + + kernels.insert( + format!("{name}_rocm"), + super::Kernel::Rocm { + cxx_flags: None, + rocm_archs: kernel.rocm_archs, + hip_flags: None, + depends: kernel.depends.clone(), + include: kernel.include.clone(), + src: kernel.src.clone(), + }, + ); + } + + kernels.insert( + name, + super::Kernel::Cuda { + cuda_capabilities: kernel.cuda_capabilities, + cuda_flags: None, + cuda_minver: None, + cxx_flags: None, + depends: kernel.depends, + include: kernel.include, + src: kernel.src, + }, + ); + } + + Ok(kernels) +} + +impl From for super::Torch { + fn from(torch: Torch) -> Self { + Self { + include: torch.include, + minver: None, + maxver: None, + pyext: torch.pyext, + src: torch.src, + } + } +} diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index 20b906bc..43e22117 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -1,15 +1,15 @@ -use std::{collections::HashMap, fmt::Display, path::PathBuf}; +use std::{ + collections::{BTreeSet, HashMap}, + fmt::Display, + path::PathBuf, +}; -use eyre::{bail, Result}; +use eyre::Result; use serde::{Deserialize, Serialize}; +use super::{Backend, Dependency}; use crate::version::Version; -use super::{ - common::Dependency, - v1::{self, Language}, -}; - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Build { @@ -117,85 +117,152 @@ pub enum Kernel { }, } -impl TryFrom for Build { +impl TryFrom for super::Build { type Error = eyre::Error; - fn try_from(build: v1::Build) -> Result { - let universal = build - .torch - .as_ref() - .map(|torch| torch.universal) - .unwrap_or(false); + fn try_from(build: Build) -> Result { + let kernels: HashMap = build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + let backends = if build.general.universal { + vec![ + Backend::Cpu, + Backend::Cuda, + Backend::Metal, + Backend::Rocm, + Backend::Xpu, + ] + } else { + let backend_set: BTreeSet = + kernels.values().map(|kernel| kernel.backend()).collect(); + backend_set.into_iter().collect() + }; + Ok(Self { - general: General::from(build.general, universal), + general: General::from_v2(build.general, backends), torch: build.torch.map(Into::into), - kernels: convert_kernels(build.kernels)?, + kernels, }) } } impl General { - fn from(general: v1::General, universal: bool) -> Self { - Self { + fn from_v2(general: General, backends: Vec) -> super::General { + let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { + Some(super::CudaGeneral { + minver: general.cuda_minver, + maxver: general.cuda_maxver, + python_depends: None, + }) + } else { + None + }; + + super::General { name: general.name, - universal, - cuda_maxver: None, - cuda_minver: None, - hub: None, + backends, + cuda, + hub: general.hub.map(Into::into), python_depends: None, + xpu: None, } } } -fn convert_kernels(v1_kernels: HashMap) -> Result> { - let mut kernels = HashMap::new(); - - for (name, kernel) in v1_kernels { - if kernel.language == Language::CudaHipify { - // We need to add an affix to avoid conflict with the CUDA kernel. - let rocm_name = format!("{name}_rocm"); - if kernels.contains_key(&rocm_name) { - bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`") - } - - kernels.insert( - format!("{name}_rocm"), - Kernel::Rocm { - cxx_flags: None, - rocm_archs: kernel.rocm_archs, - hip_flags: None, - depends: kernel.depends.clone(), - include: kernel.include.clone(), - src: kernel.src.clone(), - }, - ); +impl From for super::Hub { + fn from(hub: Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, } - - kernels.insert( - name, - Kernel::Cuda { - cuda_capabilities: kernel.cuda_capabilities, - cuda_flags: None, - cuda_minver: None, - cxx_flags: None, - depends: kernel.depends, - include: kernel.include, - src: kernel.src, - }, - ); } - - Ok(kernels) } -impl From for Torch { - fn from(torch: v1::Torch) -> Self { +impl From for super::Torch { + fn from(torch: Torch) -> Self { Self { include: torch.include, - minver: None, - maxver: None, + minver: torch.minver, + maxver: torch.maxver, pyext: torch.pyext, src: torch.src, } } } + +impl From for super::Kernel { + fn from(kernel: Kernel) -> Self { + match kernel { + Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => super::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + } => super::Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, + } + } +} diff --git a/build2cmake/src/config/v3.rs b/build2cmake/src/config/v3.rs index 14c25418..770ec05a 100644 --- a/build2cmake/src/config/v3.rs +++ b/build2cmake/src/config/v3.rs @@ -1,78 +1,11 @@ -use std::{ - collections::{BTreeSet, HashMap}, - fmt::Display, - path::PathBuf, - str::FromStr, - sync::LazyLock, -}; - -use eyre::Result; -use itertools::Itertools; +use std::collections::HashMap; +use std::path::PathBuf; + use serde::{Deserialize, Serialize}; -use thiserror::Error; -use super::{common::Dependency, v2}; +use super::Dependency; use crate::version::Version; -#[derive(Debug, Error)] -enum DependencyError { - #[error("No dependencies are defined for backend: {backend:?}")] - Backend { backend: String }, - #[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")] - Dependency { backend: String, dependency: String }, - #[error("Unknown dependency: `{dependency:?}`")] - GeneralDependency { dependency: String }, -} - -#[derive(Debug, Deserialize, Serialize)] -#[serde(deny_unknown_fields)] -struct PythonDependencies { - general: HashMap, - backends: HashMap>, -} - -impl PythonDependencies { - fn get_dependency(&self, dependency: &str) -> Result<&[String], DependencyError> { - match self.general.get(dependency) { - None => Err(DependencyError::GeneralDependency { - dependency: dependency.to_string(), - }), - Some(dep) => Ok(&dep.python), - } - } - - fn get_backend_dependency( - &self, - backend: Backend, - dependency: &str, - ) -> Result<&[String], DependencyError> { - let backend_deps = match self.backends.get(&backend) { - None => { - return Err(DependencyError::Backend { - backend: backend.to_string(), - }) - } - Some(backend_deps) => backend_deps, - }; - match backend_deps.get(dependency) { - None => Err(DependencyError::Dependency { - backend: backend.to_string(), - dependency: dependency.to_string(), - }), - Some(dep) => Ok(&dep.python), - } - } -} - -#[derive(Debug, Deserialize, Serialize)] -struct PythonDependency { - nix: Vec, - python: Vec, -} - -static PYTHON_DEPENDENCIES: LazyLock = - LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap()); - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Build { @@ -83,16 +16,6 @@ pub struct Build { pub kernels: HashMap, } -impl Build { - pub fn is_noarch(&self) -> bool { - self.kernels.is_empty() - } - - pub fn supports_backend(&self, backend: &Backend) -> bool { - self.general.backends.contains(backend) - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct General { @@ -109,60 +32,6 @@ pub struct General { pub xpu: Option, } -impl General { - /// Name of the kernel as a Python extension. - pub fn python_name(&self) -> String { - self.name.replace("-", "_") - } - - pub fn python_depends(&self) -> Box> + '_> { - let general_python_deps = match self.python_depends.as_ref() { - Some(deps) => deps, - None => { - return Box::new(std::iter::empty()); - } - }; - - Box::new(general_python_deps.iter().flat_map(move |dep| { - match PYTHON_DEPENDENCIES.get_dependency(dep) { - Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), - Err(e) => vec![Err(e.into())], - } - })) - } - - pub fn backend_python_depends( - &self, - backend: Backend, - ) -> Box> + '_> { - let backend_python_deps = match backend { - Backend::Cuda => self - .cuda - .as_ref() - .and_then(|cuda| cuda.python_depends.as_ref()), - Backend::Xpu => self - .xpu - .as_ref() - .and_then(|xpu| xpu.python_depends.as_ref()), - _ => None, - }; - - let backend_python_deps = match backend_python_deps { - Some(deps) => deps, - None => { - return Box::new(std::iter::empty()); - } - }; - - Box::new(backend_python_deps.iter().flat_map(move |dep| { - match PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) { - Ok(deps) => deps.iter().map(|s| Ok(s.clone())).collect::>(), - Err(e) => vec![Err(e.into())], - } - })) - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case")] pub struct CudaGeneral { @@ -196,27 +65,6 @@ pub struct Torch { pub src: Vec, } -impl Torch { - pub fn data_globs(&self) -> Option> { - match self.pyext.as_ref() { - Some(exts) => { - let globs = exts - .iter() - .filter(|&ext| ext != "py" && ext != "pyi") - .map(|ext| format!("\"**/*.{ext}\"")) - .collect_vec(); - if globs.is_empty() { - None - } else { - Some(globs) - } - } - - None => None, - } - } -} - #[derive(Debug, Deserialize, Serialize)] #[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")] pub enum Kernel { @@ -263,152 +111,217 @@ pub enum Kernel { }, } -impl Kernel { - pub fn cxx_flags(&self) -> Option<&[String]> { - match self { - Kernel::Cpu { cxx_flags, .. } - | Kernel::Cuda { cxx_flags, .. } - | Kernel::Metal { cxx_flags, .. } - | Kernel::Rocm { cxx_flags, .. } - | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(), +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] +#[serde(deny_unknown_fields, rename_all = "kebab-case")] +pub enum Backend { + Cpu, + Cuda, + Metal, + Rocm, + Xpu, +} + +impl From for super::Build { + fn from(build: Build) -> Self { + let kernels: HashMap = build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + Self { + general: build.general.into(), + torch: build.torch.map(Into::into), + kernels, } } +} - pub fn include(&self) -> Option<&[String]> { - match self { - Kernel::Cpu { include, .. } - | Kernel::Cuda { include, .. } - | Kernel::Metal { include, .. } - | Kernel::Rocm { include, .. } - | Kernel::Xpu { include, .. } => include.as_deref(), +impl From for super::General { + fn from(general: General) -> Self { + Self { + name: general.name, + backends: general.backends.into_iter().map(Into::into).collect(), + cuda: general.cuda.map(Into::into), + hub: general.hub.map(Into::into), + python_depends: general.python_depends, + xpu: general.xpu.map(Into::into), } } +} - pub fn backend(&self) -> Backend { - match self { - Kernel::Cpu { .. } => Backend::Cpu, - Kernel::Cuda { .. } => Backend::Cuda, - Kernel::Metal { .. } => Backend::Metal, - Kernel::Rocm { .. } => Backend::Rocm, - Kernel::Xpu { .. } => Backend::Xpu, +impl From for super::CudaGeneral { + fn from(cuda: CudaGeneral) -> Self { + Self { + minver: cuda.minver, + maxver: cuda.maxver, + python_depends: cuda.python_depends, } } +} - pub fn depends(&self) -> &[Dependency] { - match self { - Kernel::Cpu { depends, .. } - | Kernel::Cuda { depends, .. } - | Kernel::Metal { depends, .. } - | Kernel::Rocm { depends, .. } - | Kernel::Xpu { depends, .. } => depends, +impl From for super::XpuGeneral { + fn from(xpu: XpuGeneral) -> Self { + Self { + python_depends: xpu.python_depends, } } +} - pub fn src(&self) -> &[String] { - match self { - Kernel::Cpu { src, .. } - | Kernel::Cuda { src, .. } - | Kernel::Metal { src, .. } - | Kernel::Rocm { src, .. } - | Kernel::Xpu { src, .. } => src, +impl From for super::Hub { + fn from(hub: Hub) -> Self { + Self { + repo_id: hub.repo_id, + branch: hub.branch, } } } -#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)] -#[serde(deny_unknown_fields, rename_all = "kebab-case")] -pub enum Backend { - Cpu, - Cuda, - Metal, - Rocm, - Xpu, +impl From for super::Torch { + fn from(torch: Torch) -> Self { + Self { + include: torch.include, + minver: torch.minver, + maxver: torch.maxver, + pyext: torch.pyext, + src: torch.src, + } + } } -impl Display for Backend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Backend::Cpu => write!(f, "cpu"), - Backend::Cuda => write!(f, "cuda"), - Backend::Metal => write!(f, "metal"), - Backend::Rocm => write!(f, "rocm"), - Backend::Xpu => write!(f, "xpu"), +impl From for super::Backend { + fn from(backend: Backend) -> Self { + match backend { + Backend::Cpu => super::Backend::Cpu, + Backend::Cuda => super::Backend::Cuda, + Backend::Metal => super::Backend::Metal, + Backend::Rocm => super::Backend::Rocm, + Backend::Xpu => super::Backend::Xpu, } } } -impl FromStr for Backend { - type Err = String; - - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "cpu" => Ok(Backend::Cpu), - "cuda" => Ok(Backend::Cuda), - "metal" => Ok(Backend::Metal), - "rocm" => Ok(Backend::Rocm), - "xpu" => Ok(Backend::Xpu), - _ => Err(format!("Unknown backend: {s}")), +impl From for super::Kernel { + fn from(kernel: Kernel) -> Self { + match kernel { + Kernel::Cpu { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cpu { + cxx_flags, + depends, + include, + src, + }, + Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + } => super::Kernel::Cuda { + cuda_capabilities, + cuda_flags, + cuda_minver, + cxx_flags, + depends, + include, + src, + }, + Kernel::Metal { + cxx_flags, + depends, + include, + src, + } => super::Kernel::Metal { + cxx_flags, + depends, + include, + src, + }, + Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + } => super::Kernel::Rocm { + cxx_flags, + depends, + rocm_archs, + hip_flags, + include, + src, + }, + Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + } => super::Kernel::Xpu { + cxx_flags, + depends, + sycl_flags, + include, + src, + }, } } } -impl TryFrom for Build { - type Error = eyre::Error; - - fn try_from(build: v2::Build) -> Result { - let kernels: HashMap = build - .kernels - .into_iter() - .map(|(k, v)| (k, v.into())) - .collect(); - - let backends = if build.general.universal { - vec![ - Backend::Cpu, - Backend::Cuda, - Backend::Metal, - Backend::Rocm, - Backend::Xpu, - ] - } else { - let backend_set: BTreeSet = - kernels.values().map(|kernel| kernel.backend()).collect(); - backend_set.into_iter().collect() - }; - - Ok(Self { - general: General::from_v2(build.general, backends), +impl From for Build { + fn from(build: super::Build) -> Self { + Self { + general: build.general.into(), torch: build.torch.map(Into::into), - kernels, - }) + kernels: build + .kernels + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(), + } } } -impl General { - fn from_v2(general: v2::General, backends: Vec) -> Self { - let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() { - Some(CudaGeneral { - minver: general.cuda_minver, - maxver: general.cuda_maxver, - python_depends: None, - }) - } else { - None - }; - +impl From for General { + fn from(general: super::General) -> Self { Self { name: general.name, - backends, - cuda, + backends: general.backends.into_iter().map(Into::into).collect(), + cuda: general.cuda.map(Into::into), hub: general.hub.map(Into::into), - python_depends: None, - xpu: None, + python_depends: general.python_depends, + xpu: general.xpu.map(Into::into), + } + } +} + +impl From for CudaGeneral { + fn from(cuda: super::CudaGeneral) -> Self { + Self { + minver: cuda.minver, + maxver: cuda.maxver, + python_depends: cuda.python_depends, + } + } +} + +impl From for XpuGeneral { + fn from(xpu: super::XpuGeneral) -> Self { + Self { + python_depends: xpu.python_depends, } } } -impl From for Hub { - fn from(hub: v2::Hub) -> Self { +impl From for Hub { + fn from(hub: super::Hub) -> Self { Self { repo_id: hub.repo_id, branch: hub.branch, @@ -416,8 +329,8 @@ impl From for Hub { } } -impl From for Torch { - fn from(torch: v2::Torch) -> Self { +impl From for Torch { + fn from(torch: super::Torch) -> Self { Self { include: torch.include, minver: torch.minver, @@ -428,10 +341,22 @@ impl From for Torch { } } -impl From for Kernel { - fn from(kernel: v2::Kernel) -> Self { +impl From for Backend { + fn from(backend: super::Backend) -> Self { + match backend { + super::Backend::Cpu => Backend::Cpu, + super::Backend::Cuda => Backend::Cuda, + super::Backend::Metal => Backend::Metal, + super::Backend::Rocm => Backend::Rocm, + super::Backend::Xpu => Backend::Xpu, + } + } +} + +impl From for Kernel { + fn from(kernel: super::Kernel) -> Self { match kernel { - v2::Kernel::Cpu { + super::Kernel::Cpu { cxx_flags, depends, include, @@ -442,7 +367,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Cuda { + super::Kernel::Cuda { cuda_capabilities, cuda_flags, cuda_minver, @@ -459,7 +384,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Metal { + super::Kernel::Metal { cxx_flags, depends, include, @@ -470,7 +395,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Rocm { + super::Kernel::Rocm { cxx_flags, depends, rocm_archs, @@ -485,7 +410,7 @@ impl From for Kernel { include, src, }, - v2::Kernel::Xpu { + super::Kernel::Xpu { cxx_flags, depends, sycl_flags, diff --git a/build2cmake/src/main.rs b/build2cmake/src/main.rs index f8e0d0b5..38d2f0d3 100644 --- a/build2cmake/src/main.rs +++ b/build2cmake/src/main.rs @@ -15,7 +15,7 @@ use torch::{ }; mod config; -use config::{Backend, Build, BuildCompat}; +use config::{v3, Backend, Build, BuildCompat}; mod fileset; use fileset::FileSet; @@ -200,7 +200,8 @@ fn update_build(build_toml: PathBuf) -> Result<()> { let build: Build = build_compat .try_into() .context("Cannot update build configuration")?; - let pretty_toml = toml::to_string_pretty(&build)?; + let v3_build: v3::Build = build.into(); + let pretty_toml = toml::to_string_pretty(&v3_build)?; let mut writer = BufWriter::new(File::create(&build_toml).wrap_err_with(|| {