diff --git a/filesystem/filesystem.go b/filesystem/filesystem.go index 9b5b7e22..40e97fdd 100644 --- a/filesystem/filesystem.go +++ b/filesystem/filesystem.go @@ -148,6 +148,8 @@ var SortDescriptorsByLastMtime = false // "/", meaning that the entire filesystem is mounted, but // it can differ for bind mounts. // ReadOnly - True if this is a read-only mount +// MetadataPath - Absolute path to metadata information, if +// different from mountpoint. // // In order to use a Mount to store fscrypt metadata, some directories must be // setup first. Specifically, the directories created look like: @@ -176,6 +178,7 @@ type Mount struct { DeviceNumber DeviceNumber Subtree string ReadOnly bool + MetadataPath string } // PathSorter allows mounts to be sorted by Path. @@ -211,7 +214,12 @@ func (m *Mount) String() string { // BaseDir returns the path to the base fscrypt directory for this filesystem. func (m *Mount) BaseDir() string { - rawBaseDir := filepath.Join(m.Path, baseDirName) + var rawBaseDir string + if m.MetadataPath != "" { + rawBaseDir = m.MetadataPath + } else { + rawBaseDir = filepath.Join(m.Path, baseDirName) + } // We allow the base directory to be a symlink, but some callers need // the real path, so dereference the symlink here if needed. Since the // directory the symlink points to may not exist yet, we have to read diff --git a/filesystem/mountpoint.go b/filesystem/mountpoint.go index c8307801..be07c36f 100644 --- a/filesystem/mountpoint.go +++ b/filesystem/mountpoint.go @@ -191,7 +191,8 @@ func addUncontainedSubtreesRecursive(dst map[string]bool, // Then, we choose one of these trees which contains (exactly or via path // prefix) *all* mnt.Subtree. We then return the root of this tree. In both // the above examples, this algorithm returns the first Mount. -func findMainMount(filesystemMounts []*Mount) *Mount { +func findMainMount(filesystemMounts []*Mount, lookuppath string) *Mount { + metadataPath := "" // Index this filesystem's mounts by path. Note: paths are unique here, // since non-last mounts were already excluded earlier. // @@ -240,6 +241,14 @@ func findMainMount(filesystemMounts []*Mount) *Mount { uncontainedSubtrees := make(map[string]bool) addUncontainedSubtreesRecursive(uncontainedSubtrees, mntNode, allUncontainedSubtrees) if len(uncontainedSubtrees) != len(allUncontainedSubtrees) { + if mnt.Subtree == "/"+baseDirName && !allSubtrees["/"] { + metadataPath = mnt.Path + } else if len(lookuppath) > 0 && + (mainMount == nil || mainMount.ReadOnly || + (strings.HasPrefix(lookuppath, mnt.Path) && + (len(lookuppath) == len(mnt.Path) || lookuppath[len(mnt.Path)] == '/'))) { + mainMount = mnt + } continue } // If there's more than one eligible mount, they should have the @@ -250,15 +259,25 @@ func findMainMount(filesystemMounts []*Mount) *Mount { return nil } // Prefer a read-write mount to a read-only one. - if mainMount == nil || mainMount.ReadOnly { + if filepath.Base(mnt.Path) != baseDirName && + (mainMount == nil || mainMount.ReadOnly || + (len(lookuppath) > 0 && strings.HasPrefix(lookuppath, mnt.Path) && + (len(lookuppath) == len(mnt.Path) || lookuppath[len(mnt.Path)] == '/'))) { mainMount = mnt } + + if filepath.Base(mnt.Path) == baseDirName { + metadataPath = mnt.Path + } + } + if mainMount != nil { + mainMount.MetadataPath = metadataPath } return mainMount } // This is separate from loadMountInfo() only for unit testing. -func readMountInfo(r io.Reader) error { +func readMountInfo(r io.Reader, path string) error { mountsByPath := make(map[string]*Mount) mountsByDevice = make(map[DeviceNumber]*Mount) @@ -292,21 +311,21 @@ func readMountInfo(r io.Reader) error { append(allMountsByDevice[mnt.DeviceNumber], mnt) } for deviceNumber, filesystemMounts := range allMountsByDevice { - mountsByDevice[deviceNumber] = findMainMount(filesystemMounts) + mountsByDevice[deviceNumber] = findMainMount(filesystemMounts, path) } return nil } // loadMountInfo populates the Mount mappings by parsing /proc/self/mountinfo. // It returns an error if the Mount mappings cannot be populated. -func loadMountInfo() error { +func loadMountInfo(path string) error { if !mountsInitialized { file, err := os.Open("/proc/self/mountinfo") if err != nil { return err } defer file.Close() - if err := readMountInfo(file); err != nil { + if err := readMountInfo(file, path); err != nil { return err } mountsInitialized = true @@ -324,7 +343,7 @@ func filesystemLacksMainMountError(deviceNumber DeviceNumber) error { func AllFilesystems() ([]*Mount, error) { mountMutex.Lock() defer mountMutex.Unlock() - if err := loadMountInfo(); err != nil { + if err := loadMountInfo(""); err != nil { return nil, err } @@ -345,7 +364,7 @@ func UpdateMountInfo() error { mountMutex.Lock() defer mountMutex.Unlock() mountsInitialized = false - return loadMountInfo() + return loadMountInfo("") } // FindMount returns the main Mount object for the filesystem which contains the @@ -355,7 +374,7 @@ func UpdateMountInfo() error { func FindMount(path string) (*Mount, error) { mountMutex.Lock() defer mountMutex.Unlock() - if err := loadMountInfo(); err != nil { + if err := loadMountInfo(path); err != nil { return nil, err } deviceNumber, err := getNumberOfContainingDevice(path) @@ -431,7 +450,7 @@ func getMountFromLink(link string) (*Mount, error) { // Lookup mountpoints for device in global store mountMutex.Lock() defer mountMutex.Unlock() - if err := loadMountInfo(); err != nil { + if err := loadMountInfo(searchPath); err != nil { return nil, err } mnt, ok := mountsByDevice[deviceNumber] diff --git a/filesystem/mountpoint_test.go b/filesystem/mountpoint_test.go index 633ff947..6cf2e880 100644 --- a/filesystem/mountpoint_test.go +++ b/filesystem/mountpoint_test.go @@ -52,7 +52,11 @@ func endLoadMountInfoTest() { } func loadMountInfoFromString(str string) { - readMountInfo(strings.NewReader(str)) + readMountInfo(strings.NewReader(str), "") +} + +func loadMountInfoFromStringWithPath(str string, path string) { + readMountInfo(strings.NewReader(str), path) } func mountForDevice(deviceNumberStr string) *Mount { @@ -373,6 +377,87 @@ func TestLoadAmbiguousMounts(t *testing.T) { } } +// Test when the .fscrypt directory is mounted directly +func TestLoadMetadataDir(t *testing.T) { + tempDir, err := ioutil.TempDir("", "fscrypt") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + tempDir, err = filepath.Abs(tempDir) + if err != nil { + t.Fatal(err) + } + if err := os.Mkdir(tempDir+"/.fscrypt", 0700); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(tempDir+"/home", 0700); err != nil { + t.Fatal(err) + } + if err := os.Mkdir(tempDir+"/home2", 0700); err != nil { + t.Fatal(err) + } + mountinfo := fmt.Sprintf(` +222 15 259:3 /.fscrypt %s rw shared:1 - ext4 /dev/root rw +222 15 259:3 /home %s rw shared:1 - ext4 /dev/root rw +222 15 259:3 /foo %s rw shared:1 - ext4 /dev/root rw +`, tempDir+"/.fscrypt", tempDir+"/home", tempDir+"/home2") + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromStringWithPath(mountinfo, "/home/userA") + mnt := mountForDevice("259:3") + if mnt == nil { + t.Fatal("Failed to load mount") + } + if mnt.Path != tempDir+"/home" { + t.Error("Wrong path") + } + if mnt.MetadataPath != tempDir+"/.fscrypt" { + t.Error("Wrong metadata path") + } +} + +// Test when the .fscrypt directory is mounted directly but with a subtree +// equal to '/'. In this case only the mount path can help. +func TestLoadMetadataDirNoSubtree(t *testing.T) { + tempDir, err := ioutil.TempDir("", "fscrypt") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + tempDir, err = filepath.Abs(tempDir) + if err != nil { + t.Fatal(err) + } + if err := os.Mkdir(tempDir+"/.fscrypt", 0700); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(tempDir+"/home", 0700); err != nil { + t.Fatal(err) + } + if err := os.Mkdir(tempDir+"/home2", 0700); err != nil { + t.Fatal(err) + } + mountinfo := fmt.Sprintf(` +222 15 259:3 / %s rw shared:1 - ext4 /dev/root rw +222 15 259:3 / %s rw shared:1 - ext4 /dev/root rw +222 15 259:3 / %s rw shared:1 - ext4 /dev/root rw +`, tempDir+"/.fscrypt", tempDir+"/home", tempDir+"/home2") + beginLoadMountInfoTest() + defer endLoadMountInfoTest() + loadMountInfoFromStringWithPath(mountinfo, "/home/userA") + mnt := mountForDevice("259:3") + if mnt == nil { + t.Fatal("Failed to load mount") + } + if mnt.Path != tempDir+"/home" { + t.Error("Wrong path") + } + if mnt.MetadataPath != tempDir+"/.fscrypt" { + t.Error("Wrong metadata path") + } +} + // Test making a filesystem link (i.e. "UUID=...") and following it, and test // that leading and trailing whitespace in the link is ignored. func TestGetMountFromLink(t *testing.T) {